eagle.py 75.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6
from typing import cast
7

8
import numpy as np
9
10
11
import torch
import torch.nn as nn

12
13
14
15
16
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
17
from vllm.distributed.parallel_state import get_pp_group
18
from vllm.forward_context import set_forward_context
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
21
from vllm.model_executor.model_loader import get_model
22
from vllm.model_executor.models import supports_multimodal
23
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
24
from vllm.model_executor.models.interfaces import SupportsMultiModal
25
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.platforms import current_platform
28
from vllm.triton_utils import triton
29
from vllm.utils.platform_utils import is_pin_memory_available
30
31
32
33
from vllm.v1.attention.backend import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
34
from vllm.v1.attention.backends.registry import AttentionBackendEnum
35
36
37
38
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
39
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
40
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
41
from vllm.v1.kv_cache_interface import KVCacheConfig
42
from vllm.v1.sample.metadata import SamplingMetadata
43
from vllm.v1.sample.sampler import _SAMPLING_EPS
44
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
45
from vllm.v1.spec_decode.utils import (
46
47
48
    PADDING_SLOT_ID,
    compute_new_slot_mapping,
    copy_and_expand_eagle_inputs_kernel,
49
50
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
51
    extend_all_queries_by_N,
52
)
53
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
54
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
55
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
56

57
58
# Module-level logger, created by vLLM's `init_logger` under this module's name.
logger = init_logger(__name__)

59

60
class SpecDecodeBaseProposer:
61
62
63
64
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        pass_hidden_states_to_model: bool,
        runner=None,
    ):
        """Set up config handles, persistent buffers and tree metadata.

        Args:
            vllm_config: Engine configuration. Its ``speculative_config`` must
                be set, and that config's ``speculative_token_tree`` must be
                set as well.
            device: Device on which the persistent tensors are allocated.
            pass_hidden_states_to_model: Stored on the instance. When true,
                one fewer net extra input slot per request is required (see
                ``net_num_new_slots_per_request``).
            runner: Stored as ``self.runner``; may be ``None`` at construction
                time.

        Raises:
            AssertionError: If ``speculative_config`` or its
                ``speculative_token_tree`` is ``None``.
            NotImplementedError: If extra input slots are needed while the
                padded drafter batch is disabled, the model supports
                multimodal inputs, or the draft model uses M-RoPE.
            ValueError: If parallel drafting is enabled and the draft model's
                HF config defines neither ``pard_token`` nor ``ptd_token_id``.
        """
        self.vllm_config = vllm_config
        assert vllm_config.speculative_config is not None
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method
        self.pass_hidden_states_to_model = pass_hidden_states_to_model

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Unifying eagle, draft model, and parallel drafting support
        self.parallel_drafting: bool = self.speculative_config.parallel_drafting
        self.extra_slots_per_request = (
            1 if not self.parallel_drafting else self.num_speculative_tokens
        )
        self.net_num_new_slots_per_request = self.extra_slots_per_request - (
            1 if self.pass_hidden_states_to_model else 0
        )
        self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0

        self.parallel_drafting_token_id: int = 0
        self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
        if self.parallel_drafting:
            self._init_parallel_drafting_params()

        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.token_arange_np = np.arange(self.max_num_tokens)

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Builders and layer names start empty; nothing in this constructor
        # populates them.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []
        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        self.compilation_config = self.vllm_config.compilation_config

        # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
        # Keys are initialized later via initialize_cudagraph_keys() called from
        # gpu_model_runner._check_and_update_cudagraph_mode after
        # adjust_cudagraph_sizes_for_spec_decode is called.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        # Use the draft model's M-RoPE setting, not the target model's.
        # Draft models may be text-only even if the target is multimodal.
        self.uses_mrope = self.draft_model_config.uses_mrope
        self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
        self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions = torch.zeros(
                (self.uses_xdrope_dim, self.max_num_tokens + 1),
                dtype=torch.int64,
                device=device,
            )
        else:
            # Plain RoPE needs a buffer of shape (max_num_tokens,).
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        # Fail fast on configurations that are unsupported when extra input
        # slots are required.
        if self.needs_extra_input_slots:
            self._raise_if_padded_drafter_batch_disabled()
            self._raise_if_multimodal()
            self._raise_if_mrope()

        self.is_rejected_token_mask: torch.Tensor | None = None
        self.is_masked_token_mask: torch.Tensor | None = None
        if self.needs_extra_input_slots:
            # For draft models and parallel drafting, we need to keep track of
            # which tokens are rejected to update the slot mapping with padding slots.
            self.is_rejected_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )
            # For parallel drafting, we also need to keep track of which tokens
            # are parallel-padding tokens used to sample at later positions.
            # We populate this tensor even when using draft models for simplicity.
            self.is_masked_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        self._slot_mapping_buffer = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

        # Determine allowed attention backends once during initialization.
        # `None` (the non-ROCm case) means no type restriction is recorded.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata

            rocm_types = [
                TritonAttentionMetadata,
                RocmAttentionMetadata,
            ]
            # ROCM_AITER_FA is an optional backend
            # We check is_enabled() here to avoid importing the backend module during
            # auto-discovery when VLLM_ROCM_USE_AITER=0, which would trigger aiter
            # import and JIT compilation warnings. Explicit backend selection via
            # attention_config still works because the backend module is loaded
            # directly when selected, not through this auto-discovery path.
            # Check if backend module exists to allow explicit selection
            # NOTE(review): the code below tests module existence with
            # find_spec() and does not call is_enabled(); confirm that the
            # rationale above still matches the implementation.
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)

            # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
            from vllm.model_executor.layers.attention.mla_attention import (
                MLACommonMetadata,
            )

            rocm_types.append(MLACommonMetadata)

            # FlexAttention backend support
            from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata

            rocm_types.append(FlexAttentionMetadata)

            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        assert spec_token_tree is not None
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # The depth is read from the last entry of the parsed tree.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
        ).repeat(max_batch_size, 1)

274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
    def _raise_if_padded_drafter_batch_disabled(self):
        if self.speculative_config.disable_padded_drafter_batch:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting only "
                "supports padded drafter batch. Please unset "
                "disable_padded_drafter_batch in the speculative_config."
            )

    def _raise_if_multimodal(self):
        if self.supports_mm_inputs:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting "
                "does not support multimodal models yet"
            )

    def _raise_if_mrope(self):
        if self.draft_model_config.uses_mrope:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting "
                "does not support M-RoPE yet"
            )

    def _init_parallel_drafting_params(self):
        # For parallel drafting, we need the token ID to use for masked slots
        # And for EAGLE + parallel drafting, we need the hidden state tensor to use
        # for those masked slots.

        model_hf_config = self.draft_model_config.hf_config
        if hasattr(model_hf_config, "pard_token"):
            self.parallel_drafting_token_id = model_hf_config.pard_token
        elif hasattr(model_hf_config, "ptd_token_id"):
            self.parallel_drafting_token_id = model_hf_config.ptd_token_id
        else:
            raise ValueError(
                "For parallel drafting, the draft model config must have "
                "`pard_token` or `ptd_token_id` specified in its config.json."
            )

        if self.pass_hidden_states_to_model:
            self.parallel_drafting_hidden_state_tensor = torch.empty(
                self.hidden_size, dtype=self.dtype, device=self.device
            )

317
318
319
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
320
321
        if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            return self.xdrope_positions[:, :num_tokens]
322
323
324
325
326
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
327
328
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions[:, :num_tokens] = positions
329
        else:
330
331
332
333
334
            # Convert M-RoPE positions if target model uses M-RoPE
            # but draft doesn't, For text inputs, all M-RoPE
            # dimensions are identical
            if self.vllm_config.model_config.uses_mrope:
                positions = positions[0]
335
336
            self.positions[:num_tokens] = positions

337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Return slot_mapping dict for EAGLE layers.

        If slot_mapping is provided, copies it into the buffer first.
        """
        if slot_mapping is not None:
            num_actual = slot_mapping.shape[0]
            self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
            if num_tokens > num_actual:
                self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

        view = self._slot_mapping_buffer[:num_tokens]
        return {name: view for name in self.attn_layer_names + self.indexer_layer_names}

355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize the eagle cudagraph dispatcher keys.

        Eagle only supports PIECEWISE cudagraphs (via mixed_mode): the
        dispatcher is initialized with PIECEWISE when eager execution is not
        enforced and ``cudagraph_mode.mixed_mode()`` is PIECEWISE or FULL, and
        with NONE otherwise.
        This should be called after adjust_cudagraph_sizes_for_spec_decode.
        """
        eagle_cudagraph_mode = CUDAGraphMode.NONE
        # `mixed_mode()` is only consulted when eager execution is not forced.
        if not self.speculative_config.enforce_eager:
            supported_modes = (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
            if cudagraph_mode.mixed_mode() in supported_modes:
                eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE

        self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

372
373
374
375
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
376
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
377
378
379
380
381
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
382
        token_indices_to_sample: torch.Tensor | None,
383
        common_attn_metadata: CommonAttentionMetadata,
384
        sampling_metadata: SamplingMetadata,
385
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
386
        num_rejected_tokens_gpu: torch.Tensor | None = None,
387
388
389
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
390
    ) -> torch.Tensor:
391
        batch_size = common_attn_metadata.batch_size()
392

393
394
395
        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
396
397
                target_hidden_states
            )
398
            assert target_hidden_states.shape[-1] == self.hidden_size
399

400
        num_tokens, token_indices_to_sample, common_attn_metadata = (
401
402
403
404
            self.set_inputs_first_pass(
                target_token_ids=target_token_ids,
                next_token_ids=next_token_ids,
                target_positions=target_positions,
405
406
                target_hidden_states=target_hidden_states,
                token_indices_to_sample=token_indices_to_sample,
407
408
409
410
                cad=common_attn_metadata,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
            )
        )
411

412
        assert self.runner is not None
Jiayi Yao's avatar
Jiayi Yao committed
413

414
415
416
417
418
        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

419
        attn_metadata = attn_metadata_builder.build_for_drafting(
420
421
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
422
423
424
425
426
427
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
428
429
                )
            )
430
431
        else:
            draft_indexer_metadata = None
432
433
434
435
436
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
437

438
439
440
441
        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

Rémi Delacourt's avatar
Rémi Delacourt committed
442
        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
443
            num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
Rémi Delacourt's avatar
Rémi Delacourt committed
444
445
        )

446
447
448
449
        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens_dp_padded
        )
        num_input_tokens = batch_desc.num_tokens
Rémi Delacourt's avatar
Rémi Delacourt committed
450
451
452
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

453
454
455
        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

456
            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
457
458
459
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
460
            )
461

462
            input_ids = None
463
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
464
465
        else:
            input_ids = self.input_ids[:num_input_tokens]
466
            inputs_embeds = None
467

468
469
470
471
472
473
474
475
        model_kwargs = {
            "input_ids": input_ids,
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": inputs_embeds,
        }
        if self.pass_hidden_states_to_model:
            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

476
        with set_forward_context(
477
478
479
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
Rémi Delacourt's avatar
Rémi Delacourt committed
480
            num_tokens_across_dp=num_tokens_across_dp,
481
            cudagraph_runtime_mode=cudagraph_runtime_mode,
482
483
484
            slot_mapping=self._get_slot_mapping(
                num_input_tokens, common_attn_metadata.slot_mapping
            ),
485
        ):
486
487
            ret_hidden_states = self.model(**model_kwargs)
            if not self.model_returns_tuple():
Jiayi Yao's avatar
Jiayi Yao committed
488
                last_hidden_states = ret_hidden_states
489
                hidden_states = last_hidden_states
Jiayi Yao's avatar
Jiayi Yao committed
490
491
            else:
                last_hidden_states, hidden_states = ret_hidden_states
492

493
        sample_hidden_states = last_hidden_states[token_indices_to_sample]
494
        logits = self.model.compute_logits(sample_hidden_states)
495
496

        # Early exit if there is only one draft token to be generated.
497
        if self.num_speculative_tokens == 1 or self.parallel_drafting:
498
            draft_token_ids = logits.argmax(dim=-1)
499
            return draft_token_ids.view(-1, self.num_speculative_tokens)
500

501
        if self.uses_mrope:
502
            positions = self.mrope_positions[:, token_indices_to_sample]
503
        else:
504
            positions = self.positions[token_indices_to_sample]
505
506
507
508
509
510
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
511
            hidden_states = self.hidden_states[token_indices_to_sample]
XuruiYang's avatar
XuruiYang committed
512
        else:
513
            hidden_states = hidden_states[token_indices_to_sample]
514
515
516

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
517
518
519
520
521
522
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
523
                slot_mappings=slot_mappings,
524
525
526
527
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

528
        draft_token_ids = logits.argmax(dim=-1)
529

530
531
532
        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
533
534
535
536
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
537
538
                f"{self.allowed_attn_types}"
            )
539

540
541
542
        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

Rémi Delacourt's avatar
Rémi Delacourt committed
543
        batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
544
            num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
Rémi Delacourt's avatar
Rémi Delacourt committed
545
546
        )

547
548
549
550
        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            batch_size_dp_padded
        )
        input_batch_size = batch_desc.num_tokens
Rémi Delacourt's avatar
Rémi Delacourt committed
551
552
        if batch_size_across_dp is not None:
            batch_size_across_dp[self.dp_rank] = input_batch_size
553
554
555

        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
556
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
557
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
558
559
            self.token_arange_np[: batch_size + 1]
        ).clone()
560
561
562
563
564
565
566
567
568
569
570

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

571
        for token_index in range(self.num_speculative_tokens - 1):
572
            # Update the inputs.
573
574
575
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
576
577
578
579
580
581
582
583
584
585
586
587
588
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
589
590
591
592
593
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
594
595
596
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
597
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
598
599
600
            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
601
            # Increment the sequence lengths.
602
            common_attn_metadata.seq_lens += 1
603
604
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
605
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
606
607
608
609
610
611
            # Increment the maximum sequence length. We increment max_seq_len
            # unconditionally even though some seq_lens may have been capped above,
            # as max_seq_len serves as an upper bound for sequence lengths.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )
612

613
614
615
616
617
618
            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed in when common_attn_metadata.seq_lens_cpu is deprecated.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu += 1
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu += 1
619

620
            # Compute the slot mapping.
621
            block_size = attn_metadata_builder.kv_cache_spec.block_size
622
623
            if self.uses_mrope:
                # all dimensions of positions are the same
624
                block_numbers = clamped_positions[0] // block_size
625
            else:
626
                block_numbers = clamped_positions // block_size
627
            block_ids = common_attn_metadata.block_table_tensor.gather(
628
629
                dim=1, index=block_numbers.view(-1, 1)
            )
630
            block_ids = block_ids.view(-1)
631
632
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
633
                    block_ids * block_size + clamped_positions[0] % block_size
634
                )
635
636
            else:
                common_attn_metadata.slot_mapping = (
637
                    block_ids * block_size + clamped_positions % block_size
638
                )
639
640
641
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
642
            common_attn_metadata.slot_mapping.masked_fill_(
643
644
                exceeds_max_model_len, PADDING_SLOT_ID
            )
645
646

            # Rebuild attention metadata
647
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
648
649
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
650
651
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata
652

653
654
            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
655
            self._set_positions(batch_size, clamped_positions)
656
            self.hidden_states[:batch_size] = hidden_states
657
            if self.supports_mm_inputs:
658
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)
659

660
                input_ids = None
661
                inputs_embeds = self.inputs_embeds[:input_batch_size]
662
663
            else:
                input_ids = self.input_ids[:input_batch_size]
664
                inputs_embeds = None
665

666
            # Run the model.
667
668
669
670
671
672
673
674
            model_kwargs = {
                "input_ids": input_ids,
                "positions": self._get_positions(input_batch_size),
                "inputs_embeds": inputs_embeds,
            }
            if self.pass_hidden_states_to_model:
                model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]

675
            with set_forward_context(
676
677
678
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
Rémi Delacourt's avatar
Rémi Delacourt committed
679
                num_tokens_across_dp=batch_size_across_dp,
680
                cudagraph_runtime_mode=cudagraph_runtime_mode,
681
682
683
                slot_mapping=self._get_slot_mapping(
                    input_batch_size, common_attn_metadata.slot_mapping
                ),
684
            ):
685
686
                ret_hidden_states = self.model(**model_kwargs)
                if not self.model_returns_tuple():
687
688
689
690
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
691

692
            hidden_states = hidden_states[:batch_size]
693
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
694
            draft_token_ids = logits.argmax(dim=-1)
695
696
697
698
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
699
        return draft_token_ids
700

701
702
703
704
705
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """
        Write the inputs of the first drafting pass into the persistent
        buffers (`self.input_ids`, positions, `self.hidden_states`).

        Two pathways are selected by `self.needs_extra_input_slots`:

        * Default pathway: the target token ids are shifted left by one,
          `next_token_ids` is written at `token_indices_to_sample` (which
          defaults to the last slot of each request), and the positions and
          hidden states are copied as they are. `cad` is returned unchanged.
        * Extra-slots pathway: a Triton kernel copies and expands the inputs
          into the buffers, the slot mapping is recomputed, and every query
          in `cad` is extended by `self.net_num_new_slots_per_request`.
          The `token_indices_to_sample` argument is ignored here; a fresh
          tensor is allocated and filled by the kernel.

        Returns:
            A tuple `(num_tokens, token_indices_to_sample, cad)`: the number
            of tokens staged in the buffers, the indices to sample from, and
            the common attention metadata to use for drafting.
        """
        if not self.needs_extra_input_slots:
            # Default EAGLE pathway: no reshaping of input tensors needed.
            # Simply rotate the input ids and leave the positions unchanged,
            # inserting the next token ids at the last slot in each request.
            if token_indices_to_sample is None:
                token_indices_to_sample = cad.query_start_loc[1:] - 1

            num_tokens = target_token_ids.shape[0]
            # Shift the input ids by one token.
            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
            self.input_ids[: num_tokens - 1] = target_token_ids[1:]
            # Replace the last token with the next token.
            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
            self.input_ids[token_indices_to_sample] = next_token_ids

            # Copy inputs to buffer for cudagraph.
            if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
                # Only the first row of the positions is forwarded in this case.
                target_positions = target_positions[0]
            self._set_positions(num_tokens, target_positions)

            self.hidden_states[:num_tokens] = target_hidden_states

            return num_tokens, token_indices_to_sample, cad
        else:
            assert self.is_rejected_token_mask is not None
            assert self.is_masked_token_mask is not None
            # 1.
            # Call a custom triton kernel to copy input_ids and positions
            # into the correct slots in the preallocated buffers self.input_ids,
            # self.positions.
            batch_size = cad.batch_size()
            # Since we might have to copy a lot of data for prefills, we select the
            # block size based on the max query length and limit to max 256 slots/block.
            max_num_tokens_per_request = (
                cad.max_query_len + self.net_num_new_slots_per_request
            )
            BLOCK_SIZE_TOKENS = min(
                256, triton.next_power_of_2(max_num_tokens_per_request)
            )
            # Ceiling division: blocks needed to cover the longest request.
            num_blocks = (
                max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1
            ) // BLOCK_SIZE_TOKENS
            total_num_input_tokens = target_token_ids.shape[0]
            total_num_output_tokens = total_num_input_tokens + (
                self.net_num_new_slots_per_request * batch_size
            )

            # Filled by the kernel (passed as `out_new_token_indices_ptr`).
            token_indices_to_sample = torch.empty(
                batch_size * self.extra_slots_per_request,
                dtype=torch.int32,
                device=self.device,
            )

            # Destination indices to write target_hidden_states into drafting buffer.
            out_hidden_state_mapping = torch.empty(
                total_num_input_tokens, dtype=torch.int32, device=self.device
            )

            # Kernel grid: (requests, token blocks per request).
            grid = (batch_size, num_blocks)
            query_start_loc = cad.query_start_loc
            query_end_loc = cad.query_start_loc[1:] - 1
            if num_rejected_tokens_gpu is not None:
                query_end_loc = query_end_loc - num_rejected_tokens_gpu
            copy_and_expand_eagle_inputs_kernel[grid](
                # (Padded) Inputs from the target model
                target_token_ids_ptr=target_token_ids,
                target_positions_ptr=target_positions,
                next_token_ids_ptr=next_token_ids,  # sampled tokens, one per request
                # Outputs to the drafting buffers
                out_input_ids_ptr=self.input_ids,
                out_positions_ptr=self.positions,  # Doesn't support mrope for now
                out_is_rejected_token_mask_ptr=self.is_rejected_token_mask,
                out_is_masked_token_mask_ptr=self.is_masked_token_mask,
                out_new_token_indices_ptr=token_indices_to_sample,
                out_hidden_state_mapping_ptr=out_hidden_state_mapping,
                # Input metadata
                query_start_loc_ptr=query_start_loc,
                query_end_loc_ptr=query_end_loc,
                padding_token_id=0,
                parallel_drafting_token_id=self.parallel_drafting_token_id,
                # Sizing info
                # Note that we can deduce batch_size for free from the grid size
                total_input_tokens=total_num_input_tokens,
                num_padding_slots_per_request=self.extra_slots_per_request,
                shift_input_ids=self.pass_hidden_states_to_model,
                BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
            )
            if self.pass_hidden_states_to_model:
                assert self.parallel_drafting_hidden_state_tensor is not None
                self.hidden_states[out_hidden_state_mapping] = target_hidden_states
                # Use torch.where to avoid DtoH sync from boolean indexing
                mask = self.is_masked_token_mask[:total_num_output_tokens]
                torch.where(
                    mask.unsqueeze(1),
                    self.parallel_drafting_hidden_state_tensor,
                    self.hidden_states[:total_num_output_tokens],
                    out=self.hidden_states[:total_num_output_tokens],
                )

            # 2.
            # Recompute the slot mapping based on the new positions and
            # rejection mask.
            builder = (
                self._get_attention_metadata_builder()
                if self.attn_metadata_builder is None
                else self.attn_metadata_builder
            )
            new_slot_mapping = compute_new_slot_mapping(
                cad=cad,
                new_positions=self.positions[:total_num_output_tokens],
                is_rejected_token_mask=self.is_rejected_token_mask[
                    :total_num_output_tokens
                ],
                block_size=builder.kv_cache_spec.block_size,
                num_new_tokens=self.net_num_new_slots_per_request,
                max_model_len=self.max_model_len,
            )

            # 3. Update the common attention metadata with the new (meta)data
            new_cad = extend_all_queries_by_N(
                cad,
                N=self.net_num_new_slots_per_request,
                arange=self.arange,
                new_slot_mapping=new_slot_mapping,
            )

            return total_num_output_tokens, token_indices_to_sample, new_cad
839
840
841
842

    def model_returns_tuple(self) -> bool:
        """
        Tell whether the draft model's forward call is treated as
        returning a pair of tensors.

        Returns False when `self.method` is "mtp" or "draft_model" and
        True for every other method.
        """
        if self.method in ("mtp", "draft_model"):
            return False
        return True

843
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        Compute the next token id of every request from CPU-side samples.

        For each request the last sampled token id is used. If a request
        has no sampled token ids (e.g., a partial prefill), the id is taken
        from the request state instead, at index
        `num_computed_tokens + num_scheduled_tokens[req_id]`.

        Returns:
            An int32 tensor with one token id per entry of
            `sampled_token_ids`, placed on the device of `self.input_ids`.
        """
        req_ids = gpu_input_batch.req_ids
        token_id_per_req: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            token_id_per_req.append(next_token_id)
        return torch.tensor(
            token_id_per_req, dtype=torch.int32, device=self.input_ids.device
        )
875

876
877
878
879
880
881
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead. This is denoted
        the "backup" token id. It also counts rejected tokens via `sampled_token_ids`.

        Returns:
            A tuple `(next_token_ids, valid_sampled_tokens_count)`; both are
            int32 tensors with one entry per row of `sampled_token_ids`,
            filled in by `eagle_prepare_next_token_padded_kernel`.
        """
        # Precompute get_token_id for when there is no valid next token.
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        # Read the CPU sequence lengths once and convert them in a single call
        # rather than one `.item()` per request.
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [
                requests[req_ids[i]].get_token_id(seq_len)
                for i, seq_len in enumerate(seq_lens)
            ],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
        valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

        # Kernel grid: one program per request (row)
        grid = (batch_size,)

        # Find the next power of 2 for block sizes
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count

934
935
936
937
938
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.

        Returns:
            A tuple `(spec_common_attn_metadata, token_indices_to_sample,
            num_rejected_tokens_gpu)`. The two tensors are int32 with one
            entry per request and are filled in by
            `eagle_prepare_inputs_padded_kernel`.
        """
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device

        # Output buffers written by the kernel below.
        token_indices_to_sample = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )
        num_rejected_tokens_gpu = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )

        # Kernel grid: one program per request.
        grid = (num_reqs,)
        eagle_prepare_inputs_padded_kernel[grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )

        # Query lengths are unchanged in the padded pathway; they are only
        # needed here to derive `max_query_len`.
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )
994

995
996
997
998
999
1000
1001
1002
1003
1004
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> list[torch.Tensor]:
        """
        Draft a tree of speculative tokens, one tree level per iteration.

        The root level is sampled from `logits`; every following level runs
        the draft model on all tokens drafted so far, using attention
        metadata rebuilt by the tree attention metadata builder, and samples
        `self.child_drafts_per_level[level]` children per draft (argmax for
        one child, top-k otherwise).

        `slot_mappings` is accepted but not referenced in this method.

        Returns:
            A list with one tensor of draft token ids per tree level, each
            viewed as `(batch_size, -1)`.
        """
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=level + 1
            )

            # Apply new attention metadata to all layers (the same object is
            # shared, so the in-place updates below are seen by every layer).
            per_layer_attn_metadata = {
                layer_name: attn_metadata for layer_name in self.attn_layer_names
            }

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_size = tree_attn_metadata_builder.kv_cache_spec.block_size
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = block_ids * block_size + query_positions % block_size
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            # The dispatcher may select a different token count than
            # `num_tokens`; the model is run on `num_input_tokens` and the
            # outputs are sliced back to `num_tokens` afterwards.
            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens
            )
            num_input_tokens = batch_desc.num_tokens
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    num_input_tokens, attn_metadata.slot_mapping
                ),
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

1168
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.

        Returns:
            A tuple `(spec_common_attn_metadata, token_indices)` where
            `token_indices` holds, on the device of `query_start_loc`, the
            indices of the non-rejected tokens within the original batch.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # A request with n > 0 drafts rejected `n + 1 - len(sampled)` tokens;
        # a request without drafts rejected none.
        rejected_per_req = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(rejected_per_req, dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        old_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = old_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        # The numpy array is a view on the tensor's storage, so the cumsum
        # below writes directly into `new_query_start_loc_cpu`.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Plain Python int rather than a NumPy scalar.
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            _seq_lens_cpu=new_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE(review): this is the max of the pre-rejection query
            # lengths, i.e. an upper bound on the new lengths - confirm that
            # this is intended.
            max_query_len=old_query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
1269

1270
    def get_model_name(self, model: nn.Module) -> str:
        """
        Return the class name of `model`.

        If `model` has a `module` attribute (multi-GPU wrapper), the class
        name of that inner object is returned instead.
        """
        if hasattr(model, "module"):  # multi-GPU
            model = model.module
        return model.__class__.__name__

1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
    def _get_model(self) -> nn.Module:
        """
        Default method to call get_model(). Can be overridden by subclasses which
        need to customize model loading.

        The model is loaded from the draft model/load configs of
        `self.speculative_config`, inside the "eagle_head" model tag context.
        """
        # Imported locally rather than at module level.
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            return get_model(
                vllm_config=self.vllm_config,
                model_config=self.speculative_config.draft_model_config,
                load_config=self.speculative_config.draft_load_config,
            )

1290
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and connect it to the target model.

        Steps performed, in order:

        1. Record the attention / indexer layer names already present in the
           vLLM config, load the draft model, and take the difference to find
           the layers that appeared with the draft model.
        2. Build a metadata builder for the draft indexer layers, if any.
        3. Probe whether the draft model accepts multimodal embedding inputs
           and fall back to text-only mode when it does not.
        4. For multimodal targets, copy the image token index onto the draft
           model config and pick the target's language model.
        5. Share embeddings and LM head with the target where applicable.
        6. For parallel drafting, initialise the hidden-state placeholder from
           the draft model's ``mask_hidden``.

        Args:
            target_model: The already-loaded target model.
        """
        # Snapshot layer names before loading the draft model so the draft
        # model's own layers can be isolated by set difference afterwards.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config,
                AttentionLayerBase,  # type: ignore[type-abstract]
            ).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        self.model = self._get_model()

        draft_attn_layer_names = (
            get_layers_from_vllm_config(
                self.vllm_config,
                AttentionLayerBase,  # type: ignore[type-abstract]
            ).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        # Indexer layers are tracked separately from regular attention layers.
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if self.indexer_layer_names:
            # The builder is derived from the first indexer layer and is given
            # the full list of draft indexer layer names.
            first_layer = self.indexer_layer_names[0]
            self.draft_indexer_metadata_builder = (
                indexer_layers[first_layer]
                .get_attn_backend()
                .get_builder_cls()(
                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                    self.indexer_layer_names,
                    self.vllm_config,
                    self.device,
                )
            )
        else:
            self.draft_indexer_metadata_builder = None

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            assert hasattr(target_model, "config")
            target_config = target_model.config
            target_model_name = self.get_model_name(target_model)
            # The image token id lives under a different attribute depending
            # on the target architecture.
            if target_model_name in [
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
                "Qwen3VLMoeForConditionalGeneration",
                "HunYuanVLForConditionalGeneration",
                "GlmOcrForConditionalGeneration",
                "Qwen3_5ForConditionalGeneration",
                "Qwen3_5MoeForConditionalGeneration",
            ]:
                image_token_index = target_config.image_token_id
            elif target_model_name == "PixtralForConditionalGeneration":
                image_token_index = target_config.vision_config.image_token_id
            else:
                image_token_index = target_config.image_token_index
            self.model.config.image_token_index = image_token_index
            target_language_model = cast(
                SupportsMultiModal, target_model
            ).get_language_model()
        else:
            target_language_model = target_model

        self._maybe_share_embeddings(target_language_model)
        self._maybe_share_lm_head(target_language_model)

        if self.parallel_drafting and self.pass_hidden_states_to_model:
            assert self.parallel_drafting_hidden_state_tensor is not None
            # With aux hidden states, mask_hidden is viewed as 3 * hidden_size
            # and passed through combine_hidden_states; otherwise it is used
            # directly at hidden_size.
            self.parallel_drafting_hidden_state_tensor.copy_(
                self.model.combine_hidden_states(
                    self.model.mask_hidden.view(3 * self.hidden_size)
                )
                if self.eagle3_use_aux_hidden_state
                else self.model.mask_hidden.view(self.hidden_size)
            )

    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        """
        Some draft models may not have their own embedding layers, and some may
        have a duplicate copy of the target model's embedding layers. In these cases,
        we share the target model's embedding layers with the draft model to save
        memory.

        Sharing is only attempted when the pipeline-parallel world size is 1;
        otherwise the draft model keeps whatever embedding it loaded.

        Args:
            target_language_model: The target (language) model. It must expose
                a ``model`` attribute that has either ``embed_tokens`` or
                ``embedding``.

        Raises:
            AttributeError: If the target model lacks ``model``, or its inner
                model has neither ``embed_tokens`` nor ``embedding``.
        """
        if get_pp_group().world_size != 1:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )
            return

        inner_model = getattr(target_language_model, "model", None)
        if inner_model is None:
            raise AttributeError("Target model does not have 'model' attribute")
        if hasattr(inner_model, "embed_tokens"):
            target_embed_tokens = inner_model.embed_tokens
        elif hasattr(inner_model, "embedding"):
            target_embed_tokens = inner_model.embedding
        else:
            raise AttributeError(
                "Target model does not have 'embed_tokens' or 'embedding' attribute"
            )

        share_embeddings = False
        if hasattr(self.model, "has_own_embed_tokens"):
            # EAGLE model
            if not self.model.has_own_embed_tokens:
                share_embeddings = True
                logger.info(
                    "Detected EAGLE model without its own embed_tokens in the"
                    " checkpoint. Sharing target model embedding weights with the"
                    " draft model."
                )
            elif (
                isinstance(target_embed_tokens.weight, torch.Tensor)
                and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                # NOTE: The comparison is done on CPU copies to avoid extra GPU
                # memory usage in CI testing environments with limited GPU memory
                and torch.equal(
                    target_embed_tokens.weight.cpu(),
                    self.model.model.embed_tokens.weight.cpu(),
                )
            ):
                share_embeddings = True
                logger.info(
                    "Detected EAGLE model with embed_tokens identical to the target"
                    " model. Sharing target model embedding weights with the draft"
                    " model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct embed_tokens weights. "
                    "Keeping separate embedding weights from the target model."
                )
        else:
            # MTP model
            share_embeddings = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model embedding weights with the draft model."
            )

        if share_embeddings:
            # Drop the draft model's own module (if any) before rebinding the
            # attribute to the target model's embedding module.
            if hasattr(self.model.model, "embed_tokens"):
                del self.model.model.embed_tokens
            self.model.model.embed_tokens = target_embed_tokens
1456

1457
1458
1459
1460
1461
1462
    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        """
        Some draft models may not have their own LM head, and some may have a
        duplicate copy of the target model's LM head. In these cases, we share
        the target model's LM head with the draft model to save memory.

        When sharing happens, any ``shared_head.head`` found on the draft
        model's inner layers is also rebound to the target's ``lm_head``.

        Args:
            target_language_model: The target (language) model. Sharing is
                skipped if it has no ``lm_head`` attribute.
        """
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                # NOTE: The comparison is done on CPU copies to avoid extra GPU
                # memory usage in CI testing environments with limited GPU memory
                and torch.equal(
                    target_language_model.lm_head.weight.cpu(),
                    self.model.lm_head.weight.cpu(),
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head

            # MTP models call compute_logits via shared_head.head (a
            # ParallelLMHead inside each MTP layer), not self.model.lm_head.
            # If the checkpoint omits a copy of the lm_head weights at the
            # MTP layer path, shared_head.head stays uninitialised and
            # produces NaN logits. Always share it explicitly.
            inner = getattr(self.model, "model", None)
            # Explicit None check: container modules define __len__, so plain
            # truthiness would treat an empty container as missing.
            layers = getattr(inner, "layers", None) if inner is not None else None
            if layers is not None:
                items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
                for layer in items:
                    sh = getattr(layer, "shared_head", None)
                    if sh is not None and hasattr(sh, "head"):
                        del sh.head
                        sh.head = target_language_model.lm_head
                        logger.info(
                            "Shared target model lm_head with MTP shared_head.head."
                        )

1524
1525
1526
1527
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run the draft model on placeholder inputs taken from its own buffers.

        Args:
            num_tokens: Number of tokens to run before DP / cudagraph padding.
            use_cudagraphs: If True, the cudagraph dispatcher chooses the
                runtime mode and the padded token count; otherwise
                ``CUDAGraphMode.NONE`` is used.
            is_graph_capturing: If True, only a single forward pass is run;
                otherwise ``self.num_speculative_tokens`` passes are run.
            slot_mappings: Optional per-layer slot mappings supplied by the
                caller. If it contains the first draft attention layer, the
                proposer's own slot-mapping buffer is used instead.
        """
        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        num_forward_passes = 1 if is_graph_capturing else self.num_speculative_tokens
        for fwd_idx in range(num_forward_passes):
            # NOTE(review): padding and dispatch are only recomputed for the
            # first two passes; later passes reuse the values from pass 1.
            # Confirm why index 1 needs a recomputation.
            if fwd_idx <= 1:
                num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
                    num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
                )
                if use_cudagraphs:
                    cudagraph_runtime_mode, batch_desc = (
                        self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
                    )
                    num_input_tokens = batch_desc.num_tokens
                else:
                    cudagraph_runtime_mode = CUDAGraphMode.NONE
                    num_input_tokens = num_tokens_dp_padded
                if num_tokens_across_dp is not None:
                    # Record this rank's final token count.
                    num_tokens_across_dp[self.dp_rank] = num_input_tokens

            # Make sure to use EAGLE's own buffer during cudagraph capture.
            if (
                self.attn_layer_names
                and slot_mappings is not None
                and self.attn_layer_names[0] in slot_mappings
            ):
                slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
            else:
                slot_mapping_dict = slot_mappings or {}

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                # Multimodal-capable drafts are fed embeddings; text-only
                # drafts are fed token ids.
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                kwargs = dict(
                    input_ids=input_ids,
                    positions=self._get_positions(num_input_tokens),
                    inputs_embeds=inputs_embeds,
                )
                if self.pass_hidden_states_to_model:
                    kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
                self.model(**kwargs)
1585

1586
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        The first entry of ``self.attn_layer_names`` is looked up in the
        runner's attention groups, and the builder of the first group that
        contains it is returned.

        Returns:
            The metadata builder for EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
        """
        builder = None
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    break
            # Stop scanning further KV cache groups once a builder is found.
            if builder is not None:
                break

        assert builder is not None, (
            "Failed to find attention metadata builder for EAGLE layers."
        )
        return builder

1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

1627
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all drafting layers belong to the same KVCacheGroup.
        Need this assumption to ensure all drafting layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Args:
            kv_cache_config: Config whose ``kv_cache_groups`` each list the
                layer names they contain.

        Raises:
            AssertionError: If the layers in ``self.attn_layer_names`` do not
                map to exactly one KV cache group.
            KeyError: If a drafting layer is not listed in any group.
        """
        # Map every layer name to the index of the group that contains it.
        layer_to_group: dict[str, int] = {}
        for group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                layer_to_group[layer_name] = group_id

        draft_group_ids = {
            layer_to_group[layer_name] for layer_name in self.attn_layer_names
        }
        assert len(draft_group_ids) == 1, (
            "All drafting layers should belong to the same kv cache group"
        )
1649

Rémi Delacourt's avatar
Rémi Delacourt committed
1650
1651
1652
1653
1654
1655
    def _pad_batch_across_dp(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor | None]:
        """Coordinate the batch size with the other data-parallel ranks.

        Args:
            num_tokens_unpadded: Token count before any padding.
            num_tokens_padded: Token count after local padding.

        Returns:
            A pair ``(num_tokens_dp_padded, num_toks_across_dp)``. The second
            element is the value returned by ``coordinate_batch_across_dp`` and
            may be None; in that case the first element is simply
            ``num_tokens_padded``. Otherwise the first element is this rank's
            entry of ``num_toks_across_dp``.

        Raises:
            AssertionError: If the coordinator requests micro-batching.
        """
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            # DP padding is only allowed when cudagraphs are enabled.
            allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
            != CUDAGraphMode.NONE,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        num_tokens_dp_padded = num_tokens_padded
        if num_toks_across_dp is not None:
            num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
        return num_tokens_dp_padded, num_toks_across_dp

1673

1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
class EagleProposer(SpecDecodeBaseProposer):
    """Proposer that always passes hidden states to the draft model.

    It only fixes ``pass_hidden_states_to_model=True`` when initialising
    ``SpecDecodeBaseProposer``; everything else is inherited.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ) -> None:
        super().__init__(
            vllm_config,
            device,
            runner=runner,
            pass_hidden_states_to_model=True,
        )


1689
1690
1691
1692
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample draft tokens from ``logits`` and return the draft distribution.

    Only the temperature from ``sampling_metadata`` is applied; other sampling
    parameters are ignored.

    Args:
        logits: Draft logits with the vocabulary on the last dimension. Unless
            all requests are greedy, this tensor is divided by the temperature
            IN PLACE.
        sampling_metadata: Provides ``all_greedy``, ``all_random`` and the
            per-request ``temperature``.

    Returns:
        ``(next_token_ids, probs)``. When all requests are greedy, ``probs``
        is the untouched ``logits`` tensor itself; otherwise it is the float32
        softmax of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    assert sampling_metadata.temperature is not None

    # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
    # consistent with sampler.py's _SAMPLING_EPS threshold
    temperature = sampling_metadata.temperature
    # Avoid division by zero if there are greedy requests.
    if not sampling_metadata.all_random:
        is_greedy = temperature < _SAMPLING_EPS
        temperature = torch.where(is_greedy, 1.0, temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        # Greedy rows take the argmax instead of the sampled token.
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids)
    return next_token_ids, probs