eagle.py 56.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import ast
from dataclasses import replace
6
from importlib.util import find_spec
7

8
import numpy as np
9

10
11
12
import torch
import torch.nn as nn

13
from vllm.attention.backends.registry import AttentionBackendEnum
14
from vllm.config import (
15
    CompilationMode,
16
17
18
19
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
20
from vllm.distributed.parallel_state import get_pp_group
21
from vllm.forward_context import set_forward_context
22
from vllm.logger import init_logger
23
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
24
from vllm.model_executor.model_loader import get_model
25
from vllm.model_executor.models import supports_multimodal
26
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
27
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
from vllm.platforms import current_platform
30
from vllm.triton_utils import triton
31
from vllm.utils.platform_utils import is_pin_memory_available
32
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
33
34
35
36
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
37
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
38
39
40
41
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
42
from vllm.v1.kv_cache_interface import KVCacheConfig
43
from vllm.v1.sample.metadata import SamplingMetadata
44
from vllm.v1.sample.sampler import _SAMPLING_EPS
45
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
46
47
48
49
from vllm.v1.spec_decode.utils import (
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
)
50
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
51
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
52
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
53

54
55
logger = init_logger(__name__)

# Slot-mapping value written for draft positions that fall beyond the max
# model length, so the KV cache is not updated with those padding tokens.
PADDING_SLOT_ID = -1

58
59
60
61
62
63

class EagleProposer:
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Set up configuration, persistent buffers and tree metadata.

        Args:
            vllm_config: Global config; its `speculative_config` must be set.
            device: Device on which the persistent input buffers are allocated.
            runner: Optional owning runner object. It is only stored here;
                `propose` asserts that it is not None.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        # Both aranges are sliced with `[: batch_size + 1]` to build
        # query_start_loc (which has one more element than batch_size), so they
        # need at least `max_batch_size + 1` entries even when
        # max_num_tokens == max_batch_size.
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.token_arange_np = np.arange(max_num_slots_for_arange)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Builders and layer names start empty here and are filled in outside
        # of this constructor.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        # CUDA graphs are only enabled below, when compilation is on, the
        # configured mode includes PIECEWISE, and eager mode is not enforced.
        self.use_cuda_graph = False

        self.compilation_config = self.vllm_config.compilation_config
        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            cudagraph_mode = self.compilation_config.cudagraph_mode
            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                CUDAGraphMode.PIECEWISE
            ):
                logger.warning(
                    "Currently the eagle proposer only supports cudagraph_mode "
                    "PIECEWISE, if you want the drafter to use cuda graphs, "
                    "please set compilation_config.cudagraph_mode to PIECEWISE "
                    "or FULL_AND_PIECEWISE"
                )
            self.use_cuda_graph = (
                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
                and not self.speculative_config.enforce_eager
            )

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        # Staging buffer for the per-request "backup" next-token ids, which
        # are filled on the CPU side and then copied to the device.
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Determine allowed attention backends once during initialization.
        # None means no restriction is enforced.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
            # ROCM_AITER_FA is an optional backend
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)
            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # NOTE(review): this takes the depth from the last entry, so it
        # presumably relies on the tree being ordered by depth - confirm.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        # cu_drafts_per_level: running total of drafts up to each level.
        # child_drafts_per_level: drafts at a level divided (integer division)
        # by the drafts at the previous level.
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1,
            len(self.tree_choices) + 1,
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)
209

210
211
212
213
214
215
216
217
218
219
220
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
            self.positions[:num_tokens] = positions

221
222
223
224
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run the draft model and return greedily chosen draft token ids.

        The first forward pass consumes the shifted target tokens together
        with the target hidden states. If more than one speculative token is
        requested, the model is then run once per additional token on one
        token per request, or the work is delegated to `propose_tree` when the
        attention metadata is a `TreeAttentionMetadata`.

        Args:
            target_token_ids: Token ids that were fed to the target model.
            target_positions: Positions matching `target_token_ids`.
            target_hidden_states: Hidden states produced by the target model.
            next_token_ids: One next token per request.
            last_token_indices: Index of each request's last token in the
                flattened batch. When None it is derived from
                `common_attn_metadata.query_start_loc`.
            common_attn_metadata: Attention metadata for the first pass. It is
                mutated in place while the remaining tokens are drafted.
            sampling_metadata: Not read by this method.
            mm_embed_inputs: Optional `(mm_embeds, is_mm_embed)` pair that is
                forwarded to `embed_input_ids` when multimodal inputs are
                supported.

        Returns:
            Draft token ids: `[batch_size, 1]` on the early exit,
            `[batch_size, num_tree_tokens]` for tree drafting, otherwise
            `[batch_size, num_speculative_tokens]`.

        Raises:
            ValueError: If `allowed_attn_types` is set and the built attention
                metadata is not one of those types.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=num_tokens,
            num_tokens_padded=num_tokens,
        )

        # Use a piecewise CUDA graph only when enabled and the padded token
        # count fits within the largest captured size.
        cudagraph_runtime_mode = CUDAGraphMode.NONE
        if (
            self.use_cuda_graph
            and num_tokens_dp_padded
            <= self.compilation_config.max_cudagraph_capture_size
        ):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = num_tokens_dp_padded
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            # For "mtp" the model output is used as a single tensor; for every
            # other method it is unpacked as a pair.
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]
        # NOTE(review): the forward pass above compares `self.method` against
        # "mtp", while this check uses model-specific "*_mtp" names - confirm
        # which values `self.method` can actually take here.
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=batch_size,
            num_tokens_padded=batch_size,
        )

        if (
            self.use_cuda_graph
            and batch_size_dp_padded
            <= self.compilation_config.max_cudagraph_capture_size
        ):
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            input_batch_size = batch_size_dp_padded
            cudagraph_runtime_mode = CUDAGraphMode.NONE
        if batch_size_across_dp is not None:
            batch_size_across_dp[self.dp_rank] = input_batch_size

        # From here on every request contributes exactly one query token, so
        # the metadata is rewritten to describe a batch of `batch_size` tokens.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # NOTE(review): the device-side `seq_lens` below is still updated
            # in place; only the CPU copy is rebound - confirm this is intended.
            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # This is an out-of-place operation to avoid modifying the original tensor.
            common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.

            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            common_attn_metadata.num_computed_tokens_cpu = (
                common_attn_metadata.seq_lens_cpu - 1
            )

            # Compute the slot mapping.
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // self.block_size
            else:
                block_numbers = clamped_positions // self.block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * self.block_size + clamped_positions[0] % self.block_size
                )
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * self.block_size + clamped_positions % self.block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(input_batch_size),
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            # Drop the padded rows before computing logits for the real batch.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
528

529
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.

        Args:
            sampled_token_ids: Sampled token ids, one list per request.
            requests: Cached request state keyed by request id.
            gpu_input_batch: Batch whose `req_ids` maps a row index to a
                request id.
            num_scheduled_tokens: Number of scheduled tokens per request id.

        Returns:
            An int32 tensor with one next-token id per entry of
            `sampled_token_ids`, placed on the device of `self.input_ids`.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_id_list: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_id_list.append(next_token_id)
        # Keep the Python list and the tensor under separate names so the
        # `list[int]` annotation above stays accurate.
        return torch.tensor(
            next_token_id_list, dtype=torch.int32, device=self.input_ids.device
        )

562
563
564
565
566
567
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead. This
        fallback is called the "backup" token id. It also counts rejected
        tokens via `sampled_token_ids`.

        Args:
            common_attn_metadata: Supplies `seq_lens_cpu`, used as the index
                passed to `get_token_id` for each request.
            sampled_token_ids: 2-D tensor unpacked as
                `(batch_size, num_tokens)`.
            requests: Cached request state keyed by request id.
            gpu_input_batch: Supplies `num_reqs`, `req_ids` and `vocab_size`.
            discard_request_mask: Boolean tensor passed to the kernel.

        Returns:
            `(next_token_ids, valid_sampled_tokens_count)`, both int32 tensors
            of shape `(batch_size,)` on the device of `sampled_token_ids`.
        """
        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [
                requests[req_ids[i]].get_token_id(seq_lens_cpu[i].item())
                for i in range(num_reqs)
            ],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        # Output buffers; they are passed to the kernel below as its outputs.
        next_token_ids = torch.empty((batch_size,), dtype=torch.int32, device=device)
        valid_sampled_tokens_count = torch.empty(
            (batch_size,), dtype=torch.int32, device=device
        )

        # Kernel grid: one program per request (row)
        grid = (batch_size,)

        # Find the next power of 2 for block sizes
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count

622
623
624
625
626
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.

        Args:
            common_attn_metadata: Metadata of the target-model step; its
                tensors are reused in the returned metadata.
            spec_decode_metadata: Supplies `cu_num_draft_tokens` to the kernel.
            valid_sampled_tokens_count: Per-request count passed to the
                kernel; its device is used for the output tensor.

        Returns:
            `(spec_common_attn_metadata, token_indices_to_sample)`, where the
            second item is an int32 tensor of shape `(num_reqs,)` filled by
            the kernel.
        """
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device

        token_indices_to_sample = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )

        # Kernel grid: one program per request (row)
        grid = (num_reqs,)
        eagle_prepare_inputs_padded_kernel[grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_reqs,
        )

        # Per-request query lengths are the differences between consecutive
        # query start offsets; the last offset is the total token count.
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices_to_sample
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
        """
        Draft tokens level by level for tree-based speculative decoding.

        The root-level drafts are sampled from `logits`; then, for every
        remaining tree level, the accumulated drafts are run through the draft
        model with freshly built tree-attention metadata and the next level is
        sampled from the resulting logits (argmax for a single child, top-k
        otherwise).

        Side effects: overwrites the leading entries of the persistent
        `self.input_ids`, `self.positions` and `self.hidden_states` buffers and
        mutates the attention metadata returned by the builder.

        Args:
            batch_size: Number of requests in the batch.
            logits: Logits used to sample the root-level drafts.
            positions: Position of the token the drafts are appended to.
            hidden_states: Hidden states paired with `logits`.
            common_attn_metadata: Metadata that is extended for every level.

        Returns:
            One tensor of draft token ids per tree level, each viewed as
            `(batch_size, -1)`.

        Raises:
            AssertionError: If the first attention group's metadata builder is
                not a `TreeAttentionMetadataBuilder`.
        """
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
                draft_index=level + 1,
            )

            # Apply new attention metadata to all layers. Every layer shares
            # the same metadata object, so the in-place updates below are
            # visible through this mapping.
            per_layer_attn_metadata = {
                layer_name: attn_metadata for layer_name in self.attn_layer_names
            }

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // self.block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = (
                block_ids * self.block_size + query_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            if (
                self.use_cuda_graph
                and num_tokens <= self.compilation_config.max_cudagraph_capture_size
            ):
                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
            else:
                num_input_tokens = num_tokens
                cudagraph_runtime_mode = CUDAGraphMode.NONE
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens. Only the
            # trailing `level_num_drafts` entries of each request are kept.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        Prepare the inputs for speculative decoding.

        Builds a new CommonAttentionMetadata that accounts for the rejected
        tokens of every request, and returns the indices (into the original
        flattened token batch) of the tokens that should be fed to the
        speculator.

        Args:
            common_attn_metadata: Metadata describing the batch before the
                rejected tokens are removed.
            sampled_token_ids: Per-request lists of sampled token ids; only
                their lengths are used here.
            num_draft_tokens: Per-request number of draft tokens.

        Returns:
            A tuple `(spec_common_attn_metadata, token_indices)`, where
            `token_indices` lives on the device of
            `common_attn_metadata.query_start_loc`.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # A request without draft tokens has nothing to reject. Otherwise the
        # count is `n + 1 - len(sampled)`; presumably a fully accepted request
        # yields n + 1 sampled tokens -- verify against the caller.
        num_rejected_tokens_list = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(num_rejected_tokens_list, dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        old_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = old_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        # The numpy array shares memory with the tensor, so the cumsum below
        # fills `new_query_start_loc_cpu` in place.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Convert the numpy scalar to a plain int before it is stored in the
        # metadata and used as a slice bound.
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE(review): this is the maximum query length *before* the
            # rejected tokens are removed, i.e. an upper bound on the new
            # maximum. Kept as-is; confirm whether the tighter value is wanted.
            max_query_len=old_query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of `model`, unwrapping a `.module` wrapper.

        When `model` exposes a `module` attribute (multi-GPU wrappers), the
        class name of that inner object is returned instead.
        """
        if hasattr(model, "module"):  # multi-GPU
            model = model.module
        return model.__class__.__name__

    def load_model(self, target_model: nn.Module) -> None:
        """
        Load the draft model and wire it up to the target model.

        Steps performed:
          1. Load the draft model under the "eagle_head" model tag and record
             which attention / indexer layers were newly registered by it
             (layers present after loading that were not present before).
          2. Build a metadata builder for the draft indexer layers, if any.
          3. Probe whether the draft model accepts multimodal embeddings and
             fall back to text-only mode when it does not.
          4. For multimodal targets, copy the image token index into the
             draft model's config.
          5. Share `embed_tokens` (only when the PP world size is 1) and
             `lm_head` with the target model when the draft model has none of
             its own or its weights are identical to the target's.

        Args:
            target_model: The already-loaded target model.

        Raises:
            AttributeError: If embeddings may be shared but the target
                language model exposes neither `embed_tokens` nor `embedding`.
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        # Snapshot the layers registered before the draft model is loaded so
        # that the draft model's own layers can be isolated afterwards.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if self.indexer_layer_names:
            first_layer = self.indexer_layer_names[0]
            self.draft_indexer_metadata_builder = (
                indexer_layers[first_layer]
                .get_attn_backend()
                .get_builder_cls()(
                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                    self.indexer_layer_names,
                    self.vllm_config,
                    self.device,
                )
            )
        else:
            self.draft_indexer_metadata_builder = None

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            target_model_name = self.get_model_name(target_model)
            if target_model_name in [
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
            ]:
                self.model.config.image_token_index = target_model.config.image_token_id
            elif target_model_name == "PixtralForConditionalGeneration":
                self.model.config.image_token_index = (
                    target_model.config.vision_config.image_token_id
                )
            else:
                self.model.config.image_token_index = (
                    target_model.config.image_token_index
                )
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, "embed_tokens"):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, "embedding"):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            share_embeddings = False
            if hasattr(self.model, "has_own_embed_tokens"):
                # EAGLE model
                if not self.model.has_own_embed_tokens:
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model without its own embed_tokens in the"
                        " checkpoint. Sharing target model embedding weights with the"
                        " draft model."
                    )
                elif (
                    isinstance(target_embed_tokens.weight, torch.Tensor)
                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                    # TODO: Offload to CPU for comparison to avoid extra GPU memory
                    # usage in CI testing environments with limited GPU memory
                    and torch.equal(
                        target_embed_tokens.weight.cpu(),
                        self.model.model.embed_tokens.weight.cpu(),
                    )
                ):
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model with embed_tokens identical to the target"
                        " model. Sharing target model embedding weights with the draft"
                        " model."
                    )
                else:
                    logger.info(
                        "Detected EAGLE model with distinct embed_tokens weights. "
                        "Keeping separate embedding weights from the target model."
                    )
            else:
                # MTP model
                share_embeddings = True
                logger.info(
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )

            if share_embeddings:
                # Drop the draft model's own module (if any) before aliasing
                # the target's embedding module.
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
        else:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )

        # share lm_head with the target model if needed
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                # TODO: Offload to CPU for comparison to avoid extra GPU memory
                # usage in CI testing environments with limited GPU memory
                and torch.equal(
                    target_language_model.lm_head.weight.cpu(),
                    self.model.lm_head.weight.cpu(),
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        # Sharing is skipped silently when the target has no lm_head.
        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
    ) -> None:
        """
        Run the draft model on the persistent input buffers without attention
        metadata.

        Args:
            num_tokens: Number of tokens to run before DP / CUDA graph padding.
            use_cudagraphs: Request CUDA graphs; they are only used when
                `self.use_cuda_graph` is also set.
            is_graph_capturing: When True, a single forward pass is run
                instead of `self.num_speculative_tokens` passes.
        """
        # Determine if CUDA graphs should be used for this run.
        cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
        cudagraph_runtime_mode = (
            CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
        )

        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        num_forward_passes = (
            self.num_speculative_tokens if not is_graph_capturing else 1
        )
        for fwd_idx in range(num_forward_passes):
            # The padded sizes are (re)computed on the first two passes only;
            # later passes reuse the values from pass 1.
            if fwd_idx <= 1:
                num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
                    num_tokens_unpadded=num_tokens,
                    num_tokens_padded=num_tokens,
                )
                if (
                    cudagraphs_enabled
                    and num_tokens_dp_padded
                    <= self.compilation_config.max_cudagraph_capture_size
                ):
                    num_input_tokens = self.vllm_config.pad_for_cudagraph(
                        num_tokens_dp_padded
                    )
                else:
                    num_input_tokens = num_tokens_dp_padded
                if num_tokens_across_dp is not None:
                    num_tokens_across_dp[self.dp_rank] = num_input_tokens

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                # Multimodal-capable draft models are fed embeddings; text-only
                # ones are fed token ids.
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(num_input_tokens),
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=inputs_embeds,
                )
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        The first entry of `self.attn_layer_names` is looked up in the
        runner's attention groups, and the builder of the group that contains
        it is returned.

        Returns:
            The metadata builder for EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
        """
        builder = None
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    break
            if builder is not None:
                break

        assert builder is not None, (
            "Failed to find attention metadata builder for EAGLE layers."
        )
        return builder

    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Raises:
            AssertionError: If the eagle layers do not map to exactly one
                KV cache group.
            KeyError: If an eagle layer is not listed in any KV cache group.
        """
        # Map every layer name to the index of the group that lists it.
        group_id_by_layer: dict[str, int] = {}
        for group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                group_id_by_layer[layer_name] = group_id

        eagle_group_ids = {
            group_id_by_layer[layer_name] for layer_name in self.attn_layer_names
        }
        assert len(eagle_group_ids) == 1, (
            "All eagle layers should belong to the same kv cache group"
        )
    def _pad_batch_across_dp(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor]:
        """
        Coordinate the batch size across data-parallel ranks.

        Returns:
            `(num_tokens_dp_padded, num_toks_across_dp)`. When the coordinator
            returns no per-rank tensor (None), `num_tokens_padded` is passed
            through unchanged; otherwise this rank's entry of the tensor is
            returned as an int.

        Raises:
            AssertionError: If the coordinator returns microbatch slices,
                which are not supported for EAGLE.
        """
        # TODO(Flechman): support DBO ubatching
        ubatch_slices, tokens_per_rank = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            allow_dp_padding=self.use_cuda_graph,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        )
        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"

        if tokens_per_rank is None:
            return num_tokens_padded, tokens_per_rank
        return int(tokens_per_rank[self.dp_rank].item()), tokens_per_rank


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
#
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of `logits` and return the draft probs.

    Only the temperature from `sampling_metadata` is applied. Rows whose
    temperature is below `_SAMPLING_EPS` are treated as greedy and receive the
    argmax token.

    Note that on the non-greedy path `logits` is divided by the temperature
    IN PLACE, so the caller's tensor is modified.

    Args:
        logits: Logits to sample from, one row per token to draft.
        sampling_metadata: Provides `all_greedy`, `all_random` and
            `temperature`.

    Returns:
        `(next_token_ids, probs)`. When all requests are greedy, `probs` is
        the unmodified `logits` tensor itself; otherwise it is the float32
        softmax of the temperature-scaled logits.

    Raises:
        AssertionError: If not all requests are greedy and
            `sampling_metadata.temperature` is None.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    assert sampling_metadata.temperature is not None

    temperature = sampling_metadata.temperature
    # Avoid division by zero if there are greedy requests.
    if not sampling_metadata.all_random:
        # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
        # consistent with sampler.py's _SAMPLING_EPS threshold
        is_greedy = temperature < _SAMPLING_EPS
        temperature = torch.where(is_greedy, 1.0, temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        # `is_greedy` was computed above under the same condition.
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs