eagle.py 78.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import ast
4
from importlib.util import find_spec
5
from typing import Any, cast
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
12
13
14
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
15
    replace,
16
)
17
from vllm.distributed.parallel_state import get_pp_group
18
from vllm.forward_context import set_forward_context
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
21
from vllm.model_executor.model_loader import get_model
22
from vllm.model_executor.models import supports_multimodal
23
from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
24
from vllm.model_executor.models.interfaces import SupportsMultiModal
25
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
26
from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.platforms import current_platform
29
from vllm.triton_utils import triton
30
from vllm.utils.platform_utils import is_pin_memory_available
31
from vllm.v1.attention.backend import CommonAttentionMetadata
32
from vllm.v1.attention.backends.registry import AttentionBackendEnum
33
34
35
36
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
37
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
38
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
39
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
40
from vllm.v1.sample.metadata import SamplingMetadata
41
from vllm.v1.sample.sampler import _SAMPLING_EPS
42
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
43
from vllm.v1.spec_decode.utils import (
44
45
46
    PADDING_SLOT_ID,
    compute_new_slot_mapping,
    copy_and_expand_eagle_inputs_kernel,
47
48
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
49
    eagle_step_update_slot_mapping_and_metadata,
50
    extend_all_queries_by_N,
51
)
52
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
53
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
54
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
55
from vllm.v1.worker.utils import AttentionGroup
56

57
58
logger = init_logger(__name__)

59

60
class SpecDecodeBaseProposer:
61
62
63
64
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        pass_hidden_states_to_model: bool,
        runner=None,
    ):
        """Set up configuration, persistent device buffers, and tree metadata.

        Args:
            vllm_config: Global vLLM config; ``speculative_config`` must be set.
            device: Device on which all persistent buffers are allocated.
            pass_hidden_states_to_model: Whether hidden states are fed to the
                draft model as an input. Together with ``parallel_drafting``
                and ``method`` this decides how many extra input slots each
                request needs.
            runner: Accepted for interface compatibility; not read here.
        """
        self.vllm_config = vllm_config
        assert vllm_config.speculative_config is not None
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method
        self.pass_hidden_states_to_model = pass_hidden_states_to_model

        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Unifying eagle, draft model, and parallel drafting support.
        # DFlash always uses parallel drafting (all tokens in one pass),
        # but has an additional slot for the next_token_id (does not shift like EAGLE)
        self.parallel_drafting: bool = self.speculative_config.parallel_drafting
        self.extra_slots_per_request = (
            1 if not self.parallel_drafting else self.num_speculative_tokens
        )
        self.net_num_new_slots_per_request = self.extra_slots_per_request - (
            1 if (self.pass_hidden_states_to_model and self.method != "dflash") else 0
        )
        self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0

        # Must run after hidden_size / dtype / device are known: the helper may
        # allocate a hidden-state tensor from them.
        self.parallel_drafting_token_id: int = 0
        self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
        if self.parallel_drafting:
            self._init_parallel_drafting_params()

        self.use_local_argmax_reduction: bool = (
            self.speculative_config.use_local_argmax_reduction
        )

        self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.token_arange_np = np.arange(self.max_num_tokens)

        # Can be specialized by methods like DFlash to reduce the limit
        self.max_query_tokens = self.max_num_tokens
        self.max_positions = self.max_num_tokens

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        self.draft_attn_groups: list[AttentionGroup] = []
        self.kv_cache_gid: int = -1
        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        self.compilation_config = self.vllm_config.compilation_config

        # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
        # Keys are initialized later via initialize_cudagraph_keys() called from
        # gpu_model_runner._check_and_update_cudagraph_mode after
        # adjust_cudagraph_sizes_for_spec_decode is called.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        # Use draft model's M-RoPE setting, not target model's
        # Draft models may be text-only even if target is multimodal
        self.uses_mrope = self.draft_model_config.uses_mrope
        self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
        self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
        # Exactly one positions buffer is allocated: mrope_positions,
        # xdrope_positions, or positions.
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_positions + 1), dtype=torch.int64, device=device
            )
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions = torch.zeros(
                (self.uses_xdrope_dim, self.max_positions + 1),
                dtype=torch.int64,
                device=device,
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_positions,
                dtype=torch.int64,
                device=device,
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # Will be set when we initialize the attention backend
        self.block_size: int = -1

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_num_slots_for_arange = max(self.max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        if self.needs_extra_input_slots:
            self._raise_if_padded_drafter_batch_disabled()
            self._raise_if_multimodal()
            self._raise_if_mrope()

        self.is_rejected_token_mask: torch.Tensor | None = None
        self.is_masked_token_mask: torch.Tensor | None = None
        if self.needs_extra_input_slots:
            # For draft models and parallel drafting, we need to keep track of
            # which tokens are rejected to update the slot mapping with padding slots.
            self.is_rejected_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )
            # For parallel drafting, we also need to keep track of which tokens
            # are parallel-padding tokens used to sample at later positions.
            # We populate this tensor even when using draft models for simplicity.
            self.is_masked_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        self.backup_next_token_ids = CpuGpuBuffer(
            self.max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        self._slot_mapping_buffer = torch.zeros(
            self.max_positions,
            dtype=torch.int64,
            device=device,
        )

        # Determine allowed attention backends once during initialization.
        # None means "no restriction" (non-ROCm platforms).
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import (
                ROCMAiterMLASparseMetadata,
            )
            from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata

            rocm_types = [
                TritonAttentionMetadata,
                RocmAttentionMetadata,
                ROCMAiterMLASparseMetadata,
            ]
            # ROCM_AITER_FA is an optional backend: its metadata type is
            # registered only if the backend module can be located, so explicit
            # selection of that backend via attention_config keeps working.
            # NOTE(review): this only checks that the module exists (find_spec);
            # it does not consult whether AITER is enabled, so on ROCm the module
            # is imported here whenever it is present. Confirm that is intended.
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)

            # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
            from vllm.model_executor.layers.attention.mla_attention import (
                MLACommonMetadata,
            )

            rocm_types.append(MLACommonMetadata)

            # FlexAttention backend support
            from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata

            rocm_types.append(FlexAttentionMetadata)

            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        assert spec_token_tree is not None
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # The depth is taken from the last node; assumes nodes are listed in
        # non-decreasing depth order.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
        ).repeat(self.max_batch_size, 1)
290

291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
    def _raise_if_padded_drafter_batch_disabled(self):
        """Reject a speculative config that turns off the padded drafter batch.

        Raises:
            NotImplementedError: if ``disable_padded_drafter_batch`` is truthy.
        """
        if not self.speculative_config.disable_padded_drafter_batch:
            return
        raise NotImplementedError(
            "Speculative Decoding with draft models or parallel drafting only "
            "supports padded drafter batch. Please unset "
            "disable_padded_drafter_batch in the speculative_config."
        )

    def _raise_if_multimodal(self):
        """Reject multimodal models, which this drafting mode cannot handle.

        Raises:
            NotImplementedError: if ``self.supports_mm_inputs`` is truthy.
        """
        if not self.supports_mm_inputs:
            return
        raise NotImplementedError(
            "Speculative Decoding with draft models or parallel drafting "
            "does not support multimodal models yet"
        )

    def _raise_if_mrope(self):
        """Reject draft models that use M-RoPE, which is not supported here.

        Raises:
            NotImplementedError: if the draft model config has ``uses_mrope``.
        """
        if not self.draft_model_config.uses_mrope:
            return
        raise NotImplementedError(
            "Speculative Decoding with draft models or parallel drafting "
            "does not support M-RoPE yet"
        )

    def _init_parallel_drafting_params(self):
        """Resolve the placeholder token (and hidden state) for masked slots.

        Parallel drafting fills not-yet-known positions with a dedicated token
        id. It is looked up on the draft model's HF config in priority order:
        ``dflash_config["mask_token_id"]`` (DFlash), then ``pard_token``, then
        ``ptd_token_id``. When hidden states are passed to the draft model, a
        hidden-state vector for those slots is allocated too.

        Raises:
            ValueError: if none of the supported token-id fields is present.
        """
        hf_config = self.draft_model_config.hf_config

        dflash_cfg = getattr(hf_config, "dflash_config", None)
        if dflash_cfg and "mask_token_id" in dflash_cfg:
            placeholder_id = dflash_cfg["mask_token_id"]
        elif hasattr(hf_config, "pard_token"):
            placeholder_id = hf_config.pard_token
        elif hasattr(hf_config, "ptd_token_id"):
            placeholder_id = hf_config.ptd_token_id
        else:
            raise ValueError(
                "For parallel drafting, the draft model config must have "
                "`pard_token`, `ptd_token_id`, or "
                "`dflash_config.mask_token_id` specified in its config.json."
            )
        self.parallel_drafting_token_id = placeholder_id

        if not self.pass_hidden_states_to_model:
            return
        # Allocated uninitialized (torch.empty); presumably populated elsewhere
        # (e.g. during weight loading). TODO confirm.
        self.parallel_drafting_hidden_state_tensor = torch.empty(
            self.hidden_size, dtype=self.dtype, device=self.device
        )

339
340
341
    def _get_positions(self, num_tokens: int):
        """Return a view of the first ``num_tokens`` entries of the positions buffer.

        M-RoPE and XD-RoPE buffers are 2-D and are sliced along their last
        axis; the plain RoPE buffer is 1-D.
        """
        if self.uses_mrope:
            multi_dim_buffer = self.mrope_positions
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            multi_dim_buffer = self.xdrope_positions
        else:
            return self.positions[:num_tokens]
        return multi_dim_buffer[:, :num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        """Write ``positions`` into the first ``num_tokens`` slots of the active buffer."""
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
            return
        if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions[:, :num_tokens] = positions
            return
        # The target model may use M-RoPE while the draft does not. For text
        # inputs all M-RoPE rows are identical, so the first row is enough.
        target_uses_mrope = self.vllm_config.model_config.uses_mrope
        flat_positions = positions[0] if target_uses_mrope else positions
        self.positions[:num_tokens] = flat_positions

359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Map every draft attention layer name to the same slot-mapping view.

        When ``slot_mapping`` is given, it is first copied into the persistent
        buffer. Any remaining slots up to ``num_tokens`` are filled with
        ``PADDING_SLOT_ID``. All returned dict values are the same tensor view
        of the first ``num_tokens`` buffer entries.
        """
        buffer = self._slot_mapping_buffer
        if slot_mapping is not None:
            num_src = slot_mapping.shape[0]
            buffer[:num_src].copy_(slot_mapping)
            if num_src < num_tokens:
                buffer[num_src:num_tokens].fill_(PADDING_SLOT_ID)

        shared_view = buffer[:num_tokens]
        return dict.fromkeys(self._draft_attn_layer_names, shared_view)
376

377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize the cudagraph dispatcher keys used by the drafter.

        The drafter only ever uses PIECEWISE cudagraphs. They are enabled when
        eager mode is not enforced for speculation and the runner's mixed mode
        is PIECEWISE or FULL; otherwise cudagraphs are disabled (NONE). Call
        this after adjust_cudagraph_sizes_for_spec_decode.
        """
        piecewise_ok = (
            not self.speculative_config.enforce_eager
            and cudagraph_mode.mixed_mode()
            in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
        )
        eagle_cudagraph_mode = (
            CUDAGraphMode.PIECEWISE if piecewise_ok else CUDAGraphMode.NONE
        )
        self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

394
395
396
397
398
399
    def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pick the argmax draft token for each row of ``hidden_states``.

        When ``use_local_argmax_reduction`` is set, the model's
        ``get_top_tokens`` is used. Otherwise full logits are computed and
        reduced with ``argmax`` over the last dimension.
        """
        if not self.use_local_argmax_reduction:
            logits = self.model.compute_logits(hidden_states)
            return logits.argmax(dim=-1)
        return self.model.get_top_tokens(hidden_states)

400
401
402
403
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> torch.Tensor:
        """Run the draft model and return greedily sampled draft token ids.

        The first forward pass consumes the target model's tokens and hidden
        states. Draft tokens are then produced in one of three ways:

        * all at once, when ``num_speculative_tokens == 1`` or parallel
          drafting is enabled;
        * via ``propose_tree``, when tree attention metadata is in use;
        * one per step, in a decode loop of ``num_speculative_tokens - 1``
          further forward passes.

        ``sampling_metadata`` is not read in this method.

        Returns:
            Draft token ids reshaped to ``(-1, num_speculative_tokens)`` on the
            single-pass path, ``[batch_size, num_tree_tokens]`` on the tree
            path, and ``[batch_size, num_speculative_tokens]`` otherwise.

        Raises:
            ValueError: if multi-step drafting meets an attention metadata type
                outside ``self.allowed_attn_types`` (when that is set).
        """
        batch_size = common_attn_metadata.batch_size()

        if self.method in ("eagle3", "dflash"):
            assert isinstance(
                self.model,
                (
                    Eagle3LlamaForCausalLM,
                    Eagle3DeepseekV2ForCausalLM,
                    DFlashQwen3ForCausalLM,
                ),
            )
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size

        num_tokens, token_indices_to_sample, common_attn_metadata = (
            self.set_inputs_first_pass(
                target_token_ids=target_token_ids,
                next_token_ids=next_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                token_indices_to_sample=token_indices_to_sample,
                cad=common_attn_metadata,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
            )
        )

        per_layer_attn_metadata = self.build_per_layer_attn_metadata(
            common_attn_metadata
        )

        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(num_tokens)
        )

        model_kwargs, slot_mapping_size = self.build_model_inputs_first_pass(
            num_tokens, num_input_tokens, mm_embed_inputs
        )

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=self._get_slot_mapping(
                slot_mapping_size, common_attn_metadata.slot_mapping
            ),
        ):
            ret_hidden_states = self.model(**model_kwargs)
            if not self.model_returns_tuple():
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states

        sample_hidden_states = last_hidden_states[token_indices_to_sample]

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1 or self.parallel_drafting:
            draft_token_ids = self._greedy_sample(sample_hidden_states)
            return draft_token_ids.view(-1, self.num_speculative_tokens)

        # Only one positions buffer exists (see __init__). With XD-RoPE on both
        # target and draft, `self.positions` is never allocated, so read row 0
        # of `xdrope_positions`. That matches the 1-D form the decode loop
        # below uses for XD-RoPE.
        if self.uses_mrope:
            positions = self.mrope_positions[:, token_indices_to_sample]
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            positions = self.xdrope_positions[0, token_indices_to_sample]
        else:
            positions = self.positions[token_indices_to_sample]
        hidden_states = hidden_states[token_indices_to_sample]

        if any(
            isinstance(attn_metadata, TreeAttentionMetadata)
            for attn_metadata in per_layer_attn_metadata.values()
        ):
            # Draft using tree attention - requires full logits for top-k
            logits = self.model.compute_logits(sample_hidden_states)
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
                slot_mappings=slot_mappings,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = self._greedy_sample(sample_hidden_states)

        for attn_metadata in per_layer_attn_metadata.values():
            if self.allowed_attn_types is not None and not isinstance(
                attn_metadata, self.allowed_attn_types
            ):
                raise ValueError(
                    "Unsupported attention metadata type for speculative "
                    "decoding with num_speculative_tokens > 1: "
                    f"{type(attn_metadata)}. Supported types are: "
                    f"{self.allowed_attn_types}"
                )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
            self._determine_batch_execution_and_padding(batch_size)
        )

        # Every subsequent step decodes exactly one token per request.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

        block_size = self.block_size
        assert block_size > 0, "block_size has not been initialized."
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()

            # Use fused kernel for slot mapping and metadata updates.
            # Write clamped positions directly into the positions buffer to
            # avoid an extra D2D copy for the common (non-mrope) case.
            positions_1d = positions[0] if self.uses_mrope else positions
            if self.uses_mrope:
                out_pos = self.mrope_positions[0, :batch_size]
            elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
                out_pos = self.xdrope_positions[0, :batch_size]
            else:
                out_pos = self.positions[:batch_size]
            eagle_step_update_slot_mapping_and_metadata(
                positions_1d=positions_1d,
                block_table_tensor=common_attn_metadata.block_table_tensor,
                seq_lens=common_attn_metadata.seq_lens,
                block_size=block_size,
                max_model_len=self.max_model_len,
                out_clamped_positions=out_pos,
                out_slot_mapping=self._slot_mapping_buffer[:input_batch_size],
                input_batch_size=input_batch_size,
            )
            common_attn_metadata.slot_mapping = self._slot_mapping_buffer[:batch_size]
            # Broadcast the updated first row to the remaining rope rows.
            if self.uses_mrope:
                self.mrope_positions[1:, :batch_size] = self.mrope_positions[
                    0, :batch_size
                ]
                positions = self.mrope_positions[:, :batch_size]
            elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
                self.xdrope_positions[1:, :batch_size] = self.xdrope_positions[
                    0, :batch_size
                ]
                positions = self.xdrope_positions[0, :batch_size]
            else:
                positions = self.positions[:batch_size]
            # Increment the maximum sequence length. We increment max_seq_len
            # unconditionally even though some seq_lens may have been capped above,
            # as max_seq_len serves as an upper bound for sequence lengths.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )

            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed in when common_attn_metadata.seq_lens_cpu is deprecated.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu += 1
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu += 1

            # Rebuild attention metadata
            per_layer_attn_metadata = self.build_per_layer_attn_metadata(
                common_attn_metadata, draft_index=token_index + 1
            )

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            model_kwargs = {
                "input_ids": input_ids,
                "positions": self._get_positions(input_batch_size),
                "inputs_embeds": inputs_embeds,
            }
            if self.pass_hidden_states_to_model:
                model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]

            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(input_batch_size),
            ):
                ret_hidden_states = self.model(**model_kwargs)
                if not self.model_returns_tuple():
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states

            hidden_states = hidden_states[:batch_size]
            draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
645

646
647
648
649
650
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Fill the drafting buffers for the first draft forward pass.

        Writes ``self.input_ids``, the position buffer (via ``_set_positions``
        or the copy kernel) and, when used, ``self.hidden_states`` from the
        target model's outputs.

        Args:
            target_token_ids: Token ids processed by the target model.
            next_token_ids: One sampled token per request from the target model.
            target_positions: Positions of ``target_token_ids``.
            target_hidden_states: Hidden states produced by the target model.
            token_indices_to_sample: Indices whose draft outputs are sampled.
                On the default path, ``None`` means "the last token of each
                request". On the extra-slots path it is ignored and recomputed
                by the copy kernel.
            cad: Attention metadata describing the target batch.
            num_rejected_tokens_gpu: Optional per-request rejected-token counts.
                On the extra-slots path they are subtracted from each request's
                last token index.

        Returns:
            ``(num_tokens, token_indices_to_sample, cad)``. ``num_tokens`` is the
            number of tokens written to the drafting buffers. ``cad`` is
            returned unchanged on the default path and extended by
            ``net_num_new_slots_per_request`` tokens per request on the
            extra-slots path.
        """
        if not self.needs_extra_input_slots:
            # Default EAGLE pathway: no reshaping of input tensors needed.
            # Simply rotate the input ids and leave the positions unchanged,
            # inserting the next token ids at the last slot in each request.
            if token_indices_to_sample is None:
                token_indices_to_sample = cad.query_start_loc[1:] - 1

            num_tokens = target_token_ids.shape[0]
            # Shift the input ids by one token.
            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
            self.input_ids[: num_tokens - 1] = target_token_ids[1:]
            # Replace the last token of each request with its next token.
            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
            self.input_ids[token_indices_to_sample] = next_token_ids

            # Copy inputs to buffers for cudagraph.
            # When the target uses xdrope but the draft does not, only the
            # first row of the target positions is passed on.
            if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
                target_positions = target_positions[0]
            self._set_positions(num_tokens, target_positions)

            self.hidden_states[:num_tokens] = target_hidden_states

            return num_tokens, token_indices_to_sample, cad
        else:
            assert self.is_rejected_token_mask is not None
            assert self.is_masked_token_mask is not None
            # 1.
            # Call a custom triton kernel to copy input_ids and positions
            # into the correct slots in the preallocated buffers self.input_ids,
            # self.positions.
            batch_size = cad.batch_size()
            # Since we might have to copy a lot of data for prefills, we select the
            # block size based on the max query length and limit to max 256 slots/block.
            max_num_tokens_per_request = (
                cad.max_query_len + self.net_num_new_slots_per_request
            )
            BLOCK_SIZE_TOKENS = min(
                256, triton.next_power_of_2(max_num_tokens_per_request)
            )
            # Ceiling division: token blocks needed to cover the longest request.
            num_blocks = (
                max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1
            ) // BLOCK_SIZE_TOKENS
            total_num_input_tokens = target_token_ids.shape[0]
            total_num_output_tokens = total_num_input_tokens + (
                self.net_num_new_slots_per_request * batch_size
            )

            token_indices_to_sample = torch.empty(
                batch_size * self.extra_slots_per_request,
                dtype=torch.int32,
                device=self.device,
            )

            # Destination indices to write target_hidden_states into drafting buffer.
            out_hidden_state_mapping = torch.empty(
                total_num_input_tokens, dtype=torch.int32, device=self.device
            )

            # Kernel grid: one program per (request, token block) pair.
            grid = (batch_size, num_blocks)
            query_start_loc = cad.query_start_loc
            query_end_loc = cad.query_start_loc[1:] - 1
            if num_rejected_tokens_gpu is not None:
                query_end_loc = query_end_loc - num_rejected_tokens_gpu
            copy_and_expand_eagle_inputs_kernel[grid](
                # (Padded) Inputs from the target model
                target_token_ids_ptr=target_token_ids,
                target_positions_ptr=target_positions,
                next_token_ids_ptr=next_token_ids,  # sampled tokens, one per request
                # Outputs to the drafting buffers
                out_input_ids_ptr=self.input_ids,
                out_positions_ptr=self.positions,  # Doesn't support mrope for now
                out_is_rejected_token_mask_ptr=self.is_rejected_token_mask,
                out_is_masked_token_mask_ptr=self.is_masked_token_mask,
                out_new_token_indices_ptr=token_indices_to_sample,
                out_hidden_state_mapping_ptr=out_hidden_state_mapping,
                # Input metadata
                query_start_loc_ptr=query_start_loc,
                query_end_loc_ptr=query_end_loc,
                padding_token_id=0,
                parallel_drafting_token_id=self.parallel_drafting_token_id,
                # Sizing info
                # Note that we can deduce batch_size for free from the grid size
                total_input_tokens=total_num_input_tokens,
                num_padding_slots_per_request=self.extra_slots_per_request,
                shift_input_ids=self.pass_hidden_states_to_model,
                BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
            )
            if self.pass_hidden_states_to_model:
                assert self.parallel_drafting_hidden_state_tensor is not None
                self.hidden_states[out_hidden_state_mapping] = target_hidden_states
                # Overwrite masked slots with the parallel-drafting hidden state.
                # Use torch.where to avoid DtoH sync from boolean indexing.
                mask = self.is_masked_token_mask[:total_num_output_tokens]
                torch.where(
                    mask.unsqueeze(1),
                    self.parallel_drafting_hidden_state_tensor,
                    self.hidden_states[:total_num_output_tokens],
                    out=self.hidden_states[:total_num_output_tokens],
                )

            # 2.
            # Recompute the slot mapping based on the new positions and
            # rejection mask.
            assert self.block_size > 0, "block_size has not been initialized."
            new_slot_mapping = compute_new_slot_mapping(
                cad=cad,
                new_positions=self.positions[:total_num_output_tokens],
                is_rejected_token_mask=self.is_rejected_token_mask[
                    :total_num_output_tokens
                ],
                block_size=self.block_size,
                num_new_tokens=self.net_num_new_slots_per_request,
                max_model_len=self.max_model_len,
            )

            # 3. Update the common attention metadata with the new (meta)data
            new_cad = extend_all_queries_by_N(
                cad,
                N=self.net_num_new_slots_per_request,
                arange=self.arange,
                new_slot_mapping=new_slot_mapping,
            )

            return total_num_output_tokens, token_indices_to_sample, new_cad
780

781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
    def build_model_inputs_first_pass(
        self,
        num_tokens: int,
        num_input_tokens: int,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
    ) -> tuple[dict[str, Any], int]:
        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        model_kwargs = {
            "input_ids": input_ids,
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": inputs_embeds,
        }
        if self.pass_hidden_states_to_model:
            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

        return model_kwargs, num_input_tokens

    def build_per_layer_attn_metadata(
        self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
    ) -> dict[str, object]:
        """Map every draft attention layer name to its drafting metadata.

        Each draft attention group's metadata builder is asked once, via
        ``build_for_drafting``, for the metadata of ``draft_index``. The
        resulting object is shared by every layer name in that group.
        """
        layer_to_metadata: dict[str, object] = {}
        for group in self.draft_attn_groups:
            builder = group.get_metadata_builder()
            group_metadata = builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=draft_index
            )
            layer_to_metadata.update(dict.fromkeys(group.layer_names, group_metadata))
        return layer_to_metadata

824
    def model_returns_tuple(self) -> bool:
        """Return True if the draft model's forward returns a tuple.

        For the ``"mtp"``, ``"draft_model"`` and ``"dflash"`` methods, the
        forward output is a single tensor that serves as both the last and the
        carried hidden states. Every other method returns
        ``(last_hidden_states, hidden_states)``.
        """
        return self.method not in ("mtp", "draft_model", "dflash")
826

827
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """Compute the next token id of each request from host-side samples.

        For each request, the last sampled token is used. If a request has no
        sampled token ids (e.g. during a partial prefill), its next token id
        is read from the request state at position
        ``num_computed_tokens + num_scheduled_tokens[req_id]``.

        Returns:
            An int32 tensor with one token id per entry of
            ``sampled_token_ids``, on the same device as ``self.input_ids``.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids_list: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids_list.append(next_token_id)
        return torch.tensor(
            next_token_ids_list, dtype=torch.int32, device=self.input_ids.device
        )
859

860
861
    def prepare_next_token_ids_padded(
        self,
        seq_lens_cpu: torch.Tensor,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute per-request next token ids and valid sampled-token counts.

        Some requests are "discarded": their next token is not sampled and
        comes from ``request.get_token_id()`` instead. This is called the
        "backup" token id. Backup ids are precomputed on the host for every
        request, copied to the GPU, and combined with ``sampled_token_ids``
        and ``discard_request_mask`` by
        ``eagle_prepare_next_token_padded_kernel``.

        Args:
            seq_lens_cpu: Host-side sequence lengths, used to look up backup ids.
            sampled_token_ids: ``[batch_size, num_tokens]`` sampled token ids.
            requests: Request states keyed by request id.
            gpu_input_batch: Current input batch (request ids, vocab size).
            discard_request_mask: Boolean mask of discarded requests.

        Returns:
            ``(next_token_ids, valid_sampled_tokens_count)``. Both are int32
            tensors of length ``batch_size`` on ``sampled_token_ids``' device.
        """
        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [
                requests[req_ids[i]].get_token_id(seq_lens_list[i])
                for i in range(num_reqs)
            ],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
        valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

        # Kernel grid: one program per request (row)
        grid = (batch_size,)

        # Find the next power of 2 for block sizes
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count

917
918
919
920
921
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """Prepare speculator inputs without removing rejected tokens.

        This updates the common_attn_metadata for speculative decoding but
        does not account for rejected tokens. Instead, all tokens are
        included as inputs to the speculator. The rejected tokens act as
        padding and are filtered out later by ``token_indices_to_sample``.

        No blocking CPU operations should be introduced in this function.
        The ``.item()`` calls below read host-side (``*_cpu``) tensors.

        Returns:
            ``(spec_common_attn_metadata, token_indices_to_sample,
            num_rejected_tokens_gpu)``. The last two are int32 tensors of
            length ``num_reqs``, filled by ``eagle_prepare_inputs_padded_kernel``.
        """
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device

        token_indices_to_sample = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )
        num_rejected_tokens_gpu = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )

        # Kernel grid: one program per request.
        grid = (num_reqs,)
        eagle_prepare_inputs_padded_kernel[grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )

        # Query layout is reused as-is: rejected tokens stay in place as padding.
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.max_seq_len,
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )
977

978
979
980
981
982
983
984
985
986
987
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> list[torch.Tensor]:
        """Draft a token tree level by level using tree attention.

        The root level is sampled from ``logits``. Each following level is
        produced by running the draft model over all drafts so far, with
        attention metadata rebuilt by the ``TreeAttentionMetadataBuilder`` of
        the first draft attention group. A level with one child per node uses
        argmax; otherwise the top-k tokens are taken.

        Args:
            batch_size: Number of requests.
            logits: Root-level logits.
            positions: Positions of the root tokens.
            hidden_states: Hidden states of the root tokens.
            common_attn_metadata: Attention metadata for the root step.
            slot_mappings: Not referenced in this method body.

        Returns:
            One ``[batch_size, num_drafts_at_level]`` token-id tensor per level.
        """
        tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=level + 1
            )

            # Apply new attention metadata to all draft layers.
            per_layer_attn_metadata = {}
            for attn_group in self.draft_attn_groups:
                for layer_name in attn_group.layer_names:
                    per_layer_attn_metadata[layer_name] = attn_metadata

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_size = tree_attn_metadata_builder.kv_cache_spec.block_size
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = block_ids * block_size + query_positions % block_size
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens. The per-request mask is broadcast over the
            # query dimension; masked_fill_ avoids the DtoH sync that boolean
            # index assignment can incur.
            slot_mapping.masked_fill_(
                exceeds_max_model_len.unsqueeze(1), PADDING_SLOT_ID
            )
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            self.input_ids[:num_tokens] = tree_input_ids.view(-1)
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens
            )
            num_input_tokens = batch_desc.num_tokens
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    num_input_tokens, attn_metadata.slot_mapping
                ),
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

1150
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """Prepare speculator inputs with rejected tokens removed.

        This updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.

        Returns:
            ``(spec_common_attn_metadata, token_indices)``, where
            ``token_indices`` is on the device of the original
            ``query_start_loc``.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # A request with n drafts that kept len(sampled) tokens rejected
        # n + 1 - len(sampled) of them; requests without drafts reject none.
        num_rejected_tokens = torch.tensor(
            [
                n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ],
            dtype=torch.int32,
        )

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            _seq_lens_cpu=new_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # Upper bound: the pre-rejection query lengths.
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
1251

1252
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of ``model``, unwrapping one ``.module`` level.

        If ``model`` exposes a ``module`` attribute (as multi-GPU wrappers
        do), the name of that inner object's class is returned instead.
        """
        if hasattr(model, "module"):  # multi-GPU
            model = model.module
        return model.__class__.__name__

1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
    def _create_draft_vllm_config(self) -> VllmConfig:
        """Build the VllmConfig used to load the draft (proposer) model.

        When the speculative config sets ``moe_backend``, the result of
        ``replace(...)`` on the target VllmConfig is returned, with its
        kernel_config's ``moe_backend`` overridden. Otherwise the target
        ``self.vllm_config`` object itself is returned. Subclasses may
        override this to apply further config changes.
        """
        draft_moe_backend = self.speculative_config.moe_backend
        if draft_moe_backend is None:
            return self.vllm_config
        base_config = self.vllm_config
        draft_kernel_config = replace(
            base_config.kernel_config,
            moe_backend=draft_moe_backend,
        )
        return replace(base_config, kernel_config=draft_kernel_config)

1272
1273
1274
1275
1276
1277
1278
    def _get_model(self) -> nn.Module:
        """Load the draft model via ``get_model()``.

        The model is built from the config returned by
        ``_create_draft_vllm_config()`` and the speculative config's draft
        model and load configs, under the ``"eagle_head"`` model tag.
        Subclasses that need custom model loading may override this.
        """
        from vllm.compilation.backends import set_model_tag

        draft_vllm_config = self._create_draft_vllm_config()
        with set_model_tag("eagle_head"):
            model = get_model(
                vllm_config=draft_vllm_config,
                model_config=self.speculative_config.draft_model_config,
                load_config=self.speculative_config.draft_load_config,
            )
        return model

1288
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and wire it up to ``target_model``.

        1. Loads the draft model. Attention layers that appear only after
           loading are recorded in ``self._draft_attn_layer_names``.
        2. If multimodal inputs are enabled, probes the draft model's
           ``embed_input_ids``. If the probe fails, falls back to text-only.
        3. For multimodal targets, copies the image placeholder token id onto
           the draft model config and unwraps the target's language model.
        4. Shares embeddings / lm_head with the target where applicable.
        5. For parallel drafting with hidden states, fills
           ``self.parallel_drafting_hidden_state_tensor`` from the draft
           model's ``mask_hidden``.
        """
        # Snapshot the attention layers that exist before the draft model is
        # loaded; anything registered afterwards belongs to the draft model.
        layers_before_draft = set(
            get_layers_from_vllm_config(
                self.vllm_config,
                AttentionLayerBase,  # type: ignore[type-abstract]
            ).keys()
        )

        self.model = self._get_model()

        # Find draft layers (attention layers added by draft model)
        layers_after_draft = get_layers_from_vllm_config(
            self.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        self._draft_attn_layer_names = (
            set(layers_after_draft.keys()) - layers_before_draft
        )

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                probe_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(probe_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if not supports_multimodal(target_model):
            target_language_model = target_model
        else:
            # handle multimodality
            assert hasattr(target_model, "config")
            target_config = target_model.config
            model_name = self.get_model_name(target_model)
            if model_name in (
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
                "Qwen3VLMoeForConditionalGeneration",
                "HunYuanVLForConditionalGeneration",
                "GlmOcrForConditionalGeneration",
                "Qwen3_5ForConditionalGeneration",
                "Qwen3_5MoeForConditionalGeneration",
            ):
                image_token_index = target_config.image_token_id
            elif model_name == "PixtralForConditionalGeneration":
                image_token_index = target_config.vision_config.image_token_id
            elif model_name == "KimiK25ForConditionalGeneration":
                image_token_index = target_config.media_placeholder_token_id
            else:
                image_token_index = target_config.image_token_index
            self.model.config.image_token_index = image_token_index
            target_language_model = cast(
                SupportsMultiModal, target_model
            ).get_language_model()

        self._maybe_share_embeddings(target_language_model)
        self._maybe_share_lm_head(target_language_model)

        wants_parallel_hidden = (
            self.parallel_drafting
            and self.pass_hidden_states_to_model
            and self.parallel_drafting_hidden_state_tensor is not None
        )
        if not wants_parallel_hidden:
            return

        mask_source = self.model.mask_hidden.view(-1)
        if self.eagle3_use_aux_hidden_state:
            # EAGLE3: mask_hidden stores all aux hidden states,
            # project through combine_hidden_states
            mask_source = self.model.combine_hidden_states(mask_source)
        self.parallel_drafting_hidden_state_tensor.copy_(mask_source)
1368
1369
1370
1371
1372
1373
1374
1375

    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        """
        Some draft models may not have their own embedding layers, and some may
        have a duplicate copy of the target model's embedding layers. In these
        cases, we share the target model's embedding layers with the draft model
        to save memory.

        Sharing is only considered when the pipeline-parallel world size is 1.
        The target embedding is ``target_language_model.model.embed_tokens``,
        or ``.embedding`` if ``embed_tokens`` is absent. It is installed as
        ``self.model.model.embed_tokens`` when one of these holds:

        - the draft model has no ``has_own_embed_tokens`` attribute (MTP model);
        - ``self.model.has_own_embed_tokens`` is falsy;
        - both embedding weights are tensors with the same shape and identical
          values.

        Raises:
            AttributeError: if ``target_language_model`` has no ``model``
                attribute, or that inner model has neither ``embed_tokens`` nor
                ``embedding``.
        """

        def _same_weights(lhs: object, rhs: object) -> bool:
            # Shapes are compared first. Tensors of different shapes can never
            # be equal, so this skips the (potentially large) host copies.
            return (
                isinstance(lhs, torch.Tensor)
                and isinstance(rhs, torch.Tensor)
                and lhs.shape == rhs.shape
                # TODO: Offload to CPU for comparison to avoid extra GPU memory
                # usage in CI testing environments with limited GPU memory
                and torch.equal(lhs.cpu(), rhs.cpu())
            )

        if get_pp_group().world_size != 1:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )
            return

        inner_model = getattr(target_language_model, "model", None)
        if inner_model is None:
            raise AttributeError("Target model does not have 'model' attribute")
        if hasattr(inner_model, "embed_tokens"):
            target_embed_tokens = inner_model.embed_tokens
        elif hasattr(inner_model, "embedding"):
            target_embed_tokens = inner_model.embedding
        else:
            raise AttributeError(
                "Target model does not have 'embed_tokens' or 'embedding' attribute"
            )

        if not hasattr(self.model, "has_own_embed_tokens"):
            # MTP model
            share_embeddings = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model embedding weights with the draft model."
            )
        elif not self.model.has_own_embed_tokens:
            # EAGLE model whose checkpoint carries no embedding weights.
            share_embeddings = True
            logger.info(
                "Detected EAGLE model without its own embed_tokens in the"
                " checkpoint. Sharing target model embedding weights with the"
                " draft model."
            )
        elif _same_weights(
            target_embed_tokens.weight, self.model.model.embed_tokens.weight
        ):
            share_embeddings = True
            logger.info(
                "Detected EAGLE model with embed_tokens identical to the target"
                " model. Sharing target model embedding weights with the draft"
                " model."
            )
        else:
            share_embeddings = False
            logger.info(
                "Detected EAGLE model with distinct embed_tokens weights. "
                "Keeping separate embedding weights from the target model."
            )

        if share_embeddings:
            if hasattr(self.model.model, "embed_tokens"):
                del self.model.model.embed_tokens
            self.model.model.embed_tokens = target_embed_tokens
1437

1438
1439
1440
1441
1442
1443
    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        """
        Some draft models may not have their own LM head, and some may have a
        duplicate copy of the target model's LM head. In these cases, we share
        the target model's LM head with the draft model to save memory.

        The target ``lm_head`` is installed as ``self.model.lm_head`` when:

        - the draft model has no ``has_own_lm_head`` attribute (MTP model);
        - ``self.model.has_own_lm_head`` is falsy;
        - both lm_head weights are tensors with the same shape and identical
          values.

        When shared, every ``shared_head.head`` found on
        ``self.model.model.layers`` is also pointed at the target ``lm_head``.
        Finally, the ``use_local_argmax_reduction`` option is validated against
        the draft model.

        Raises:
            ValueError: if ``use_local_argmax_reduction`` is enabled but the
                draft model does not implement ``get_top_tokens()``.
        """

        def _same_weights(lhs: object, rhs: object) -> bool:
            # Shapes are compared first. Tensors of different shapes can never
            # be equal, so this skips the (potentially large) host copies.
            return (
                isinstance(lhs, torch.Tensor)
                and isinstance(rhs, torch.Tensor)
                and lhs.shape == rhs.shape
                # TODO: Offload to CPU for comparison to avoid extra GPU memory
                # usage in CI testing environments with limited GPU memory
                and torch.equal(lhs.cpu(), rhs.cpu())
            )

        if not hasattr(self.model, "has_own_lm_head"):
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )
        elif not self.model.has_own_lm_head:
            # EAGLE model whose checkpoint carries no lm_head weights.
            share_lm_head = True
            logger.info(
                "Detected EAGLE model without its own lm_head in the checkpoint. "
                "Sharing target model lm_head weights with the draft model."
            )
        elif (
            hasattr(target_language_model, "lm_head")
            and hasattr(target_language_model.lm_head, "weight")
            and hasattr(self.model.lm_head, "weight")
            and _same_weights(
                target_language_model.lm_head.weight, self.model.lm_head.weight
            )
        ):
            share_lm_head = True
            logger.info(
                "Detected EAGLE model with lm_head identical to the target model. "
                "Sharing target model lm_head weights with the draft model."
            )
        else:
            share_lm_head = False
            logger.info(
                "Detected EAGLE model with distinct lm_head weights. "
                "Keeping separate lm_head weights from the target model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            target_lm_head = target_language_model.lm_head
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_lm_head

            # MTP models call compute_logits via shared_head.head (a
            # ParallelLMHead inside each MTP layer), not self.model.lm_head.
            # If the checkpoint omits a copy of the lm_head weights at the
            # MTP layer path, shared_head.head stays uninitialised and
            # produces NaN logits. Always share it explicitly.
            # `is not None` semantics via getattr: an nn.Module that defines
            # __len__ (e.g. an empty container) is falsy, so truthiness is
            # not a safe presence check here.
            inner = getattr(self.model, "model", None)
            layers = getattr(inner, "layers", None)
            if layers is not None:
                items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
                num_shared = 0
                for layer in items:
                    shared_head = getattr(layer, "shared_head", None)
                    if shared_head is not None and hasattr(shared_head, "head"):
                        del shared_head.head
                        shared_head.head = target_lm_head
                        num_shared += 1
                if num_shared:
                    # One log line instead of one per MTP layer.
                    logger.info(
                        "Shared target model lm_head with %d MTP "
                        "shared_head.head module(s).",
                        num_shared,
                    )

        if not self.use_local_argmax_reduction:
            return
        if not hasattr(self.model, "get_top_tokens"):
            raise ValueError(
                "use_local_argmax_reduction is enabled but draft model "
                f"{self.model.__class__.__name__} does not implement "
                "get_top_tokens()."
            )
        # Warn if draft model has vocab remapping, which forces fallback
        # to the full-logits path (negating the optimization).
        if getattr(self.model, "draft_id_to_target_id", None) is not None:
            logger.warning(
                "use_local_argmax_reduction is enabled but draft model "
                "uses draft_id_to_target_id vocab remapping. The "
                "optimization will be bypassed (falling back to full "
                "logits gather + argmax)."
            )
        else:
            logger.info(
                "Using local argmax reduction for draft token generation "
                "(communication: O(2*tp_size) vs O(vocab_size))."
            )

1532
1533
1534
1535
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run the draft model on its preallocated input buffers.

        No real requests are involved. Inputs are slices of ``self.input_ids``
        or ``self.inputs_embeds``, plus ``self.hidden_states`` when hidden
        states are passed to the model. A single forward pass is run during
        CUDA graph capture or when parallel drafting is enabled. Otherwise one
        pass is run per speculative token.

        Args:
            num_tokens: Unpadded number of tokens for each forward pass.
            use_cudagraphs: If False, dispatch is restricted to
                ``CUDAGraphMode.NONE``.
            is_graph_capturing: Whether this call happens during CUDA graph
                capture.
            slot_mappings: Optional per-layer slot mappings. If they contain an
                entry for a draft attention layer, the proposer's own slot
                mapping is used instead. Only the first name in
                ``self._draft_attn_layer_names`` is checked.
        """
        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        if is_graph_capturing or self.parallel_drafting:
            num_forward_passes = 1
        else:
            num_forward_passes = self.num_speculative_tokens

        for pass_idx in range(num_forward_passes):
            if pass_idx < 2:
                # Recomputed on the first two passes only, then reused.
                # NOTE(review): presumably mirrors the real propose loop; keep
                # the call count unchanged since the helper may coordinate
                # with other DP ranks.
                cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
                    self._determine_batch_execution_and_padding(
                        num_tokens, use_cudagraphs=use_cudagraphs
                    )
                )

            # Make sure to use EAGLE's own buffer during cudagraph capture.
            draft_layer_names = self._draft_attn_layer_names
            use_own_slot_mapping = (
                bool(draft_layer_names)
                and slot_mappings is not None
                and next(iter(draft_layer_names)) in slot_mappings
            )
            if use_own_slot_mapping:
                slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
            else:
                slot_mapping_dict = slot_mappings or {}

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                if self.supports_mm_inputs:
                    input_ids, inputs_embeds = (
                        None,
                        self.inputs_embeds[:num_input_tokens],
                    )
                else:
                    input_ids, inputs_embeds = (
                        self.input_ids[:num_input_tokens],
                        None,
                    )

                model_kwargs = {
                    "input_ids": input_ids,
                    "positions": self._get_positions(num_input_tokens),
                    "inputs_embeds": inputs_embeds,
                }
                if self.pass_hidden_states_to_model:
                    model_kwargs["hidden_states"] = self.hidden_states[
                        :num_input_tokens
                    ]
                self.model(**model_kwargs)
1586

1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """Return whether the draft head consumes auxiliary hidden states.

        Non-"eagle3" methods always return False. For eagle3, the answer
        defaults to True. It is overridden by ``use_aux_hidden_state`` inside
        the ``eagle_config`` dict of the draft model's hf_config, if present.
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) set it to
        False and use the last layer output directly, just like eagle1.
        """
        if self.method != "eagle3":
            return False

        hf_config = self.draft_model_config.hf_config
        eagle_config = getattr(hf_config, "eagle_config", None)
        if eagle_config is None:
            # Assume that eagle3 heads use aux hidden states by default
            return True
        return eagle_config.get("use_aux_hidden_state", True)

1603
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all drafting layers belong to the same KVCacheGroup.
        Need this assumption to ensure all drafting layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Raises:
            AssertionError: if a drafting layer is not listed in any KV cache
                group, or if the drafting layers span zero or several groups.
        """
        # Map each layer name to the index of its KV cache group. If a name
        # appears in several groups, the last one wins (as before).
        group_of_layer: dict[str, int] = {
            layer_name: group_id
            for group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }

        missing = sorted(
            name for name in self._draft_attn_layer_names if name not in group_of_layer
        )
        assert not missing, (
            f"Drafting layers {missing} do not belong to any kv cache group"
        )

        draft_group_ids = {
            group_of_layer[layer_name] for layer_name in self._draft_attn_layer_names
        }
        assert len(draft_group_ids) == 1, (
            "All drafting layers should belong to the same kv cache group"
        )
1625

1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
    def initialize_attn_backend(
        self,
        kv_cache_config: KVCacheConfig,
        kernel_block_sizes: list[int] | None = None,
    ) -> None:
        """
        Initialize AttentionGroups for draft layers using kv_cache_config.
        Called from the model runner's initialize_metadata_builders.

        Draft layers are grouped by attention backend class, and each group
        gets its own metadata builders. This method sets:

        - ``self.kv_cache_gid``: the KV cache group holding the draft layers.
        - ``self.draft_attn_groups``: the per-backend attention groups.
        - ``self.block_size``: the first group's metadata builder block size.

        Args:
            kv_cache_config: KV cache layout. All draft layers must live in a
                single group, as checked by ``validate_same_kv_cache_group``.
            kernel_block_sizes: Optional kernel block sizes indexed by KV cache
                group id. ``None``, or a list too short to contain the draft
                group, means no override.
        """
        all_attn_layers = get_layers_from_vllm_config(
            self.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )

        # Find which kv_cache_group the draft layers belong to
        self.validate_same_kv_cache_group(kv_cache_config)
        kv_cache_spec = None
        for gid, group in enumerate(kv_cache_config.kv_cache_groups):
            if self._draft_attn_layer_names & set(group.layer_names):
                self.kv_cache_gid = gid
                kv_cache_spec = group.kv_cache_spec
                break

        attention_groups: dict[tuple[str, str], AttentionGroup] = {}
        if kv_cache_spec is not None:
            # Loop-invariant: every draft layer lives in the same group.
            kernel_block_size = (
                kernel_block_sizes[self.kv_cache_gid]
                if kernel_block_sizes is not None
                and self.kv_cache_gid < len(kernel_block_sizes)
                else None
            )
            # Sorted so group order (and hence which group supplies
            # block_size) does not depend on per-process string hash seeds.
            for layer_name in sorted(self._draft_attn_layer_names):
                attn_backend = all_attn_layers[layer_name].get_attn_backend()
                backend_key = attn_backend.full_cls_name()
                existing_group = attention_groups.get(backend_key)
                if existing_group is not None:
                    existing_group.layer_names.append(layer_name)
                    continue

                layer_kv_cache_spec = kv_cache_spec
                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
                        layer_name
                    ]

                attn_group = AttentionGroup(
                    backend=attn_backend,
                    layer_names=[layer_name],
                    kv_cache_spec=layer_kv_cache_spec,
                    kv_cache_group_id=self.kv_cache_gid,
                )
                attn_group.create_metadata_builders(
                    self.vllm_config,
                    self.device,
                    kernel_block_size=kernel_block_size,
                )
                attention_groups[backend_key] = attn_group

        self.draft_attn_groups = list(attention_groups.values())
        assert self.draft_attn_groups, (
            "No attention groups could be created for the drafting layers"
        )

        self.block_size = (
            self.draft_attn_groups[0].get_metadata_builder().kv_cache_spec.block_size
        )
        logger.debug("Using block size %d for drafting layers", self.block_size)
1687

1688
    def _determine_batch_execution_and_padding(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        """Choose the cudagraph mode and padded token count for a draft pass.

        Args:
            num_tokens: Unpadded number of tokens on this rank.
            use_cudagraphs: If False, only ``CUDAGraphMode.NONE`` may be chosen.

        Returns:
            ``(cudagraph_mode, num_tokens_padded, num_tokens_across_dp)``. The
            last element is ``None`` unless data parallelism is enabled and
            ``coordinate_batch_across_dp`` returned per-rank token counts. In
            that case this rank's padded count is taken from it and dispatch
            is repeated with the DP-synced mode.
        """
        dispatcher = self.cudagraph_dispatcher
        allowed_modes = None if use_cudagraphs else {CUDAGraphMode.NONE}
        cudagraph_mode, batch_desc = dispatcher.dispatch(
            num_tokens,
            valid_modes=allowed_modes,
        )
        num_tokens_padded = batch_desc.num_tokens

        parallel_config = self.vllm_config.parallel_config
        if parallel_config.data_parallel_size <= 1:
            return cudagraph_mode, num_tokens_padded, None

        # Extra coordination when running data-parallel since we need to
        # coordinate across ranks
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
            coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=parallel_config,
                allow_microbatching=False,
                num_tokens_padded=num_tokens_padded,
                cudagraph_mode=cudagraph_mode.value,
            )
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
        if num_tokens_across_dp is None:
            return cudagraph_mode, num_tokens_padded, None

        # Extract DP-synced values, then re-dispatch with DP padding so we
        # have the correct batch_descriptor.
        dp_rank = self.dp_rank
        num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
        cudagraph_mode, batch_desc = dispatcher.dispatch(
            num_tokens_padded,
            valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
        )
        # Assert to make sure the agreed upon token count is correct
        # otherwise num_tokens_across_dp will no-longer be valid
        assert batch_desc.num_tokens == num_tokens_padded
        num_tokens_across_dp[dp_rank] = num_tokens_padded
        return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
Rémi Delacourt's avatar
Rémi Delacourt committed
1731

1732

1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
class EagleProposer(SpecDecodeBaseProposer):
    """EAGLE speculative-decoding proposer.

    A thin specialization of ``SpecDecodeBaseProposer``. It configures the base
    proposer with ``pass_hidden_states_to_model=True``, so hidden states are
    passed to the draft model's forward.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        # EAGLE draft heads take hidden states as an input.
        super().__init__(
            vllm_config, device, pass_hidden_states_to_model=True, runner=runner
        )


1748
1749
1750
1751
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample draft token ids and return them with the draft probabilities.

    If every request is greedy, returns ``(logits.argmax(-1), logits)``. For
    greedy requests draft_probs is not used in rejection sampling, so the raw
    logits tensor is returned in its place.

    Otherwise:
    - ``logits`` is divided IN PLACE by the per-request temperature.
      Requests with temperature below ``_SAMPLING_EPS`` use 1.0 instead.
    - A float32 softmax gives ``probs``.
    - Each row is sampled as ``argmax(probs / q)`` with ``q ~ Exp(1)``,
      which draws from ``probs``.
    - Greedy rows use ``probs.argmax`` instead.

    Returns:
        ``(next_token_ids, probs)``.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        return logits.argmax(dim=-1), logits

    temperature = sampling_metadata.temperature
    assert temperature is not None

    # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
    # consistent with sampler.py's _SAMPLING_EPS threshold.
    greedy_mask = None
    if not sampling_metadata.all_random:
        greedy_mask = temperature < _SAMPLING_EPS
        # Avoid division by zero for the greedy requests.
        temperature = torch.where(greedy_mask, 1.0, temperature)

    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    exp_noise = torch.empty_like(probs).exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    sampled_ids = probs.div(exp_noise).argmax(dim=-1).view(-1)

    if greedy_mask is None:
        return sampled_ids, probs
    return torch.where(greedy_mask, probs.argmax(dim=-1), sampled_ids), probs