eagle.py 76.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6
from typing import cast
7

8
import numpy as np
9
10
11
import torch
import torch.nn as nn

12
13
14
15
16
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
17
from vllm.distributed.parallel_state import get_pp_group
18
from vllm.forward_context import set_forward_context
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
21
from vllm.model_executor.model_loader import get_model
22
from vllm.model_executor.models import supports_multimodal
23
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
24
from vllm.model_executor.models.interfaces import SupportsMultiModal
25
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.platforms import current_platform
28
from vllm.triton_utils import triton
29
from vllm.utils.platform_utils import is_pin_memory_available
30
31
32
33
from vllm.v1.attention.backend import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
34
from vllm.v1.attention.backends.registry import AttentionBackendEnum
35
36
37
38
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
39
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
40
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
41
from vllm.v1.kv_cache_interface import KVCacheConfig
42
from vllm.v1.sample.metadata import SamplingMetadata
43
from vllm.v1.sample.sampler import _SAMPLING_EPS
44
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
45
from vllm.v1.spec_decode.utils import (
46
47
48
    PADDING_SLOT_ID,
    compute_new_slot_mapping,
    copy_and_expand_eagle_inputs_kernel,
49
50
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
51
    extend_all_queries_by_N,
52
)
53
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
54
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
55
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
56

57
58
# Module-level logger, created from this module's `__name__`.
logger = init_logger(__name__)

59

60
class SpecDecodeBaseProposer:
61
62
63
64
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        pass_hidden_states_to_model: bool,
        runner=None,
    ):
        """Set up configuration, persistent buffers and tree metadata.

        Args:
            vllm_config: Global vLLM configuration. Its `speculative_config`
                must not be None.
            device: Device on which all persistent tensors are allocated.
            pass_hidden_states_to_model: Whether `propose` passes a
                `hidden_states` keyword argument to the draft model. It also
                reduces the net number of new input slots per request by one.
            runner: Stored as `self.runner`; `propose` asserts that it is
                not None.

        Raises:
            NotImplementedError: If extra input slots are needed and padded
                drafter batch is disabled, the model supports multimodal
                inputs, or the draft model uses M-RoPE.
            ValueError: Propagated from `_init_parallel_drafting_params` when
                parallel drafting is enabled and no mask token id is found.
        """
        self.vllm_config = vllm_config
        assert vllm_config.speculative_config is not None
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method
        self.pass_hidden_states_to_model = pass_hidden_states_to_model

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Unifying eagle, draft model, and parallel drafting support
        self.parallel_drafting: bool = self.speculative_config.parallel_drafting
        # One extra slot per request normally; one per speculative token when
        # drafting in parallel.
        self.extra_slots_per_request = (
            1 if not self.parallel_drafting else self.num_speculative_tokens
        )
        # Passing hidden states to the model offsets one of the extra slots.
        self.net_num_new_slots_per_request = self.extra_slots_per_request - (
            1 if self.pass_hidden_states_to_model else 0
        )
        self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0

        self.parallel_drafting_token_id: int = 0
        self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
        if self.parallel_drafting:
            # Relies on hidden_size / dtype / device having been set above.
            self._init_parallel_drafting_params()

        self.use_local_argmax_reduction: bool = (
            self.speculative_config.use_local_argmax_reduction
        )

        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.token_arange_np = np.arange(self.max_num_tokens)

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Builders and layer names start empty here; `propose` falls back to
        # `_get_attention_metadata_builder()` while the builder is still None.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        self.compilation_config = self.vllm_config.compilation_config

        # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
        # Keys are initialized later via initialize_cudagraph_keys() called from
        # gpu_model_runner._check_and_update_cudagraph_mode after
        # adjust_cudagraph_sizes_for_spec_decode is called.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        # Use draft model's M-RoPE setting, not target model's
        # Draft models may be text-only even if target is multimodal
        self.uses_mrope = self.draft_model_config.uses_mrope
        self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
        self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
        # Exactly one of mrope_positions / xdrope_positions / positions is
        # allocated, depending on the rotary-embedding variant in use.
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            # NOTE(review): the leading dimension comes from the target
            # model's `uses_xdrope_dim`, not the draft model's - confirm this
            # is intended.
            self.xdrope_positions = torch.zeros(
                (self.uses_xdrope_dim, self.max_num_tokens + 1),
                dtype=torch.int64,
                device=device,
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        # Fail fast on configurations that the extra-input-slot path does not
        # support.
        if self.needs_extra_input_slots:
            self._raise_if_padded_drafter_batch_disabled()
            self._raise_if_multimodal()
            self._raise_if_mrope()

        self.is_rejected_token_mask: torch.Tensor | None = None
        self.is_masked_token_mask: torch.Tensor | None = None
        if self.needs_extra_input_slots:
            # For draft models and parallel drafting, we need to keep track of
            # which tokens are rejected to update the slot mapping with padding slots.
            self.is_rejected_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )
            # For parallel drafting, we also need to keep track of which tokens
            # are parallel-padding tokens used to sample at later positions.
            # We populate this tensor even when using draft models for simplicity.
            self.is_masked_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Persistent slot-mapping buffer; `_get_slot_mapping` hands out views
        # of it.
        self._slot_mapping_buffer = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

        # Determine allowed attention backends once during initialization.
        # Stays None (no restriction enforced by `propose`) off ROCm.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata

            rocm_types = [
                TritonAttentionMetadata,
                RocmAttentionMetadata,
            ]
            # ROCM_AITER_FA is an optional backend
            # We check is_enabled() here to avoid importing the backend module during
            # auto-discovery when VLLM_ROCM_USE_AITER=0, which would trigger aiter
            # import and JIT compilation warnings. Explicit backend selection via
            # attention_config still works because the backend module is loaded
            # directly when selected, not through this auto-discovery path.
            # Check if backend module exists to allow explicit selection
            # NOTE(review): the comment above mentions is_enabled(), but the
            # code below only checks that the module can be found via
            # find_spec - confirm which behavior is intended.
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)

            # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
            from vllm.model_executor.layers.attention.mla_attention import (
                MLACommonMetadata,
            )

            rocm_types.append(MLACommonMetadata)

            # FlexAttention backend support
            from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata

            rocm_types.append(FlexAttentionMetadata)

            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        assert spec_token_tree is not None
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # The depth is read from the last node only.
        # NOTE(review): assumes the last entry is one of the deepest nodes and
        # that the tree is non-empty - confirm against the config contract.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            # A node of length k is counted at level index k - 1.
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            # Running total of drafts up to and including this level.
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            # Integer ratio of this level's draft count to the previous one.
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
        ).repeat(max_batch_size, 1)

277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
    def _raise_if_padded_drafter_batch_disabled(self):
        if self.speculative_config.disable_padded_drafter_batch:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting only "
                "supports padded drafter batch. Please unset "
                "disable_padded_drafter_batch in the speculative_config."
            )

    def _raise_if_multimodal(self):
        if self.supports_mm_inputs:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting "
                "does not support multimodal models yet"
            )

    def _raise_if_mrope(self):
        if self.draft_model_config.uses_mrope:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting "
                "does not support M-RoPE yet"
            )

    def _init_parallel_drafting_params(self):
        # For parallel drafting, we need the token ID to use for masked slots
        # And for EAGLE + parallel drafting, we need the hidden state tensor to use
        # for those masked slots.

        model_hf_config = self.draft_model_config.hf_config
        if hasattr(model_hf_config, "pard_token"):
            self.parallel_drafting_token_id = model_hf_config.pard_token
        elif hasattr(model_hf_config, "ptd_token_id"):
            self.parallel_drafting_token_id = model_hf_config.ptd_token_id
        else:
            raise ValueError(
                "For parallel drafting, the draft model config must have "
                "`pard_token` or `ptd_token_id` specified in its config.json."
            )

        if self.pass_hidden_states_to_model:
            self.parallel_drafting_hidden_state_tensor = torch.empty(
                self.hidden_size, dtype=self.dtype, device=self.device
            )

320
321
322
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
323
324
        if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            return self.xdrope_positions[:, :num_tokens]
325
326
327
328
329
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
330
331
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions[:, :num_tokens] = positions
332
        else:
333
334
335
336
337
            # Convert M-RoPE positions if target model uses M-RoPE
            # but draft doesn't, For text inputs, all M-RoPE
            # dimensions are identical
            if self.vllm_config.model_config.uses_mrope:
                positions = positions[0]
338
339
            self.positions[:num_tokens] = positions

340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Return slot_mapping dict for EAGLE layers.

        If slot_mapping is provided, copies it into the buffer first.
        """
        if slot_mapping is not None:
            num_actual = slot_mapping.shape[0]
            self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
            if num_tokens > num_actual:
                self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

        view = self._slot_mapping_buffer[:num_tokens]
        return {name: view for name in self.attn_layer_names + self.indexer_layer_names}

358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize cudagraph dispatcher keys for eagle.

        Eagle only supports PIECEWISE cudagraphs (via mixed_mode): the
        dispatcher is initialized with PIECEWISE when eager mode is not
        enforced and the mixed mode is PIECEWISE or FULL, and with NONE
        otherwise.
        This should be called after adjust_cudagraph_sizes_for_spec_decode.
        """
        # `mixed_mode()` is only consulted when eager execution is not
        # enforced (short-circuit evaluation).
        piecewise_allowed = (
            not self.speculative_config.enforce_eager
            and cudagraph_mode.mixed_mode()
            in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
        )
        eagle_cudagraph_mode = (
            CUDAGraphMode.PIECEWISE if piecewise_allowed else CUDAGraphMode.NONE
        )
        self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

375
376
377
378
379
380
    def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Greedy-sample draft tokens from hidden states."""
        if self.use_local_argmax_reduction:
            return self.model.get_top_tokens(hidden_states)
        return self.model.compute_logits(hidden_states).argmax(dim=-1)

381
382
383
384
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> torch.Tensor:
        """Propose draft tokens for every request in the batch.

        Runs the draft model once over the inputs staged by
        `set_inputs_first_pass` and greedily samples the first draft token.
        If one token per request suffices (or parallel drafting is enabled)
        the result is returned right away. If the first-pass attention
        metadata is a `TreeAttentionMetadata`, drafting is delegated to
        `propose_tree`. Otherwise the model is run for
        `num_speculative_tokens - 1` further single-token steps, updating
        `common_attn_metadata` in place between steps.

        Args:
            target_token_ids: Target-model token ids, `[num_tokens]`.
            target_positions: Target positions, `[num_tokens]` or
                `[3, num_tokens]` when M-RoPE is enabled.
            target_hidden_states: Target hidden states,
                `[num_tokens, hidden_size]`. For the "eagle3" method they are
                first passed through the model's `combine_hidden_states`.
            next_token_ids: Next token per request, `[batch_size]`.
            token_indices_to_sample: Forwarded to `set_inputs_first_pass`,
                whose returned value is used to index the model outputs.
            common_attn_metadata: Attention metadata shared by the draft
                layers; replaced by the value `set_inputs_first_pass` returns
                and then mutated in place in the multi-step loop.
            sampling_metadata: Currently not read by this method.
            mm_embed_inputs: Optional `(mm_embeds, is_mm_embed)` pair used
                when `supports_mm_inputs` is set.
            num_rejected_tokens_gpu: Optional per-request rejected-token
                counts; when given, they are subtracted from `seq_lens`
                before the multi-step loop.
            slot_mappings: Only forwarded to `propose_tree`.

        Returns:
            Draft token ids: `[batch_size, num_speculative_tokens]` on the
            single-step and chained paths, or `[batch_size, num_tree_tokens]`
            on the tree-attention path.

        Raises:
            ValueError: On the chained path, if `allowed_attn_types` is set
                and the attention metadata is not one of those types.
        """
        batch_size = common_attn_metadata.batch_size()

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Stage the first-pass inputs; this may also replace the sampling
        # indices and the attention metadata.
        num_tokens, token_indices_to_sample, common_attn_metadata = (
            self.set_inputs_first_pass(
                target_token_ids=target_token_ids,
                next_token_ids=next_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                token_indices_to_sample=token_indices_to_sample,
                cad=common_attn_metadata,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
            )
        )

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        # `num_input_tokens` is the (possibly padded) token count actually
        # fed to the model; `num_tokens` is the real count.
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(num_tokens)
        )

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            # Embeddings are computed for the real tokens only, but the
            # padded slice of the buffer is passed to the model.
            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        model_kwargs = {
            "input_ids": input_ids,
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": inputs_embeds,
        }
        if self.pass_hidden_states_to_model:
            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=self._get_slot_mapping(
                num_input_tokens, common_attn_metadata.slot_mapping
            ),
        ):
            ret_hidden_states = self.model(**model_kwargs)
            # Models return either a single tensor or a
            # (last_hidden_states, hidden_states) pair.
            if not self.model_returns_tuple():
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states

        sample_hidden_states = last_hidden_states[token_indices_to_sample]

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1 or self.parallel_drafting:
            draft_token_ids = self._greedy_sample(sample_hidden_states)
            return draft_token_ids.view(-1, self.num_speculative_tokens)

        # NOTE(review): the XD-RoPE buffer is not handled here, unlike in
        # `_get_positions` / `_set_positions` - confirm this is intended.
        if self.uses_mrope:
            positions = self.mrope_positions[:, token_indices_to_sample]
        else:
            positions = self.positions[token_indices_to_sample]
        # For the listed MTP methods the next step's hidden states come from
        # the input buffer rather than from the model output.
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[token_indices_to_sample]
        else:
            hidden_states = hidden_states[token_indices_to_sample]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention - requires full logits for top-k
            logits = self.model.compute_logits(sample_hidden_states)
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
                slot_mappings=slot_mappings,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = self._greedy_sample(sample_hidden_states)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        # From here on each step processes one token per request, so the
        # padding decision is made on the batch size.
        cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
            self._determine_batch_execution_and_padding(batch_size)
        )

        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # NOTE(review): the update below is nevertheless in place - confirm
            # whether the comment above or the code reflects the intent.
            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
            # Increment the maximum sequence length. We increment max_seq_len
            # unconditionally even though some seq_lens may have been capped above,
            # as max_seq_len serves as an upper bound for sequence lengths.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )

            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed when common_attn_metadata.seq_lens_cpu is deprecated.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu += 1
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu += 1

            # Compute the slot mapping.
            block_size = attn_metadata_builder.kv_cache_spec.block_size
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // block_size
            else:
                block_numbers = clamped_positions // block_size
            # Look up, per request, the block id at the computed block number.
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions[0] % block_size
                )
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions % block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            # Only the attention layers are refreshed; indexer layers keep
            # the metadata built for the first pass.
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            model_kwargs = {
                "input_ids": input_ids,
                "positions": self._get_positions(input_batch_size),
                "inputs_embeds": inputs_embeds,
            }
            if self.pass_hidden_states_to_model:
                model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]

            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    input_batch_size, common_attn_metadata.slot_mapping
                ),
            ):
                ret_hidden_states = self.model(**model_kwargs)
                if not self.model_returns_tuple():
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states

            # Drop the padded rows before sampling and before the next step.
            hidden_states = hidden_states[:batch_size]
            draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
694

695
696
697
698
699
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """
        Fill the drafting buffers for the first draft forward pass.

        The pathway is selected by `self.needs_extra_input_slots`:

        * Default pathway: the target token ids are shifted left by one into
          `self.input_ids`, `next_token_ids` is written at
          `token_indices_to_sample` (the last token of every request when the
          argument is None), and positions and hidden states are copied into
          the buffers. `cad` is returned unchanged.
        * Extra-slot pathway: `copy_and_expand_eagle_inputs_kernel` writes
          `self.input_ids`, `self.positions` and the rejected/masked token
          masks, the slot mapping is recomputed, and every query in `cad` is
          extended by `self.net_num_new_slots_per_request` tokens.
          `token_indices_to_sample` is always freshly allocated here, so the
          value passed in is ignored.

        `num_rejected_tokens_gpu` is only read on the extra-slot pathway, where
        it is subtracted from each request's last-token index.

        Returns a tuple of (number of tokens to run the draft model on,
        token indices to sample, common attention metadata to use).
        """
        if not self.needs_extra_input_slots:
            # Default EAGLE pathway: no reshaping of input tensors needed.
            # Simply rotate the input ids and leave the positions unchanged,
            # Inserting the next token ids at the last slot in each request.
            if token_indices_to_sample is None:
                token_indices_to_sample = cad.query_start_loc[1:] - 1

            num_tokens = target_token_ids.shape[0]
            # Shift the input ids by one token.
            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
            # The bound is clamped at 0: with no tokens, `num_tokens - 1` would
            # be -1 and select almost the whole buffer instead of nothing.
            self.input_ids[: max(num_tokens - 1, 0)] = target_token_ids[1:]
            # Replace the last token with the next token.
            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
            self.input_ids[token_indices_to_sample] = next_token_ids

            # copy inputs to buffer for cudagraph
            if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
                target_positions = target_positions[0]
            self._set_positions(num_tokens, target_positions)

            self.hidden_states[:num_tokens] = target_hidden_states

            return num_tokens, token_indices_to_sample, cad
        else:
            assert self.is_rejected_token_mask is not None
            assert self.is_masked_token_mask is not None
            # 1.
            # Call a custom triton kernel to copy input_ids and positions
            # into the correct slots in the preallocated buffers self.input_ids,
            # self.positions.
            batch_size = cad.batch_size()
            # Since we might have to copy a lot of data for prefills, we select the
            # block size based on the max query length and limit to max 256 slots/block.
            max_num_tokens_per_request = (
                cad.max_query_len + self.net_num_new_slots_per_request
            )
            BLOCK_SIZE_TOKENS = min(
                256, triton.next_power_of_2(max_num_tokens_per_request)
            )
            # Ceiling division: enough blocks to cover the longest request.
            num_blocks = (
                max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1
            ) // BLOCK_SIZE_TOKENS
            total_num_input_tokens = target_token_ids.shape[0]
            total_num_output_tokens = total_num_input_tokens + (
                self.net_num_new_slots_per_request * batch_size
            )

            token_indices_to_sample = torch.empty(
                batch_size * self.extra_slots_per_request,
                dtype=torch.int32,
                device=self.device,
            )

            # Destination indices to write target_hidden_states into drafting buffer.
            out_hidden_state_mapping = torch.empty(
                total_num_input_tokens, dtype=torch.int32, device=self.device
            )

            # Kernel grid: one program per (request, token block) pair
            grid = (batch_size, num_blocks)
            query_start_loc = cad.query_start_loc
            query_end_loc = cad.query_start_loc[1:] - 1
            if num_rejected_tokens_gpu is not None:
                query_end_loc = query_end_loc - num_rejected_tokens_gpu
            copy_and_expand_eagle_inputs_kernel[grid](
                # (Padded) Inputs from the target model
                target_token_ids_ptr=target_token_ids,
                target_positions_ptr=target_positions,
                next_token_ids_ptr=next_token_ids,  # sampled tokens, one per request
                # Outputs to the drafting buffers
                out_input_ids_ptr=self.input_ids,
                out_positions_ptr=self.positions,  # Doesn't support mrope for now
                out_is_rejected_token_mask_ptr=self.is_rejected_token_mask,
                out_is_masked_token_mask_ptr=self.is_masked_token_mask,
                out_new_token_indices_ptr=token_indices_to_sample,
                out_hidden_state_mapping_ptr=out_hidden_state_mapping,
                # Input metadata
                query_start_loc_ptr=query_start_loc,
                query_end_loc_ptr=query_end_loc,
                padding_token_id=0,
                parallel_drafting_token_id=self.parallel_drafting_token_id,
                # Sizing info
                # Note that we can deduce batch_size for free from the grid size
                total_input_tokens=total_num_input_tokens,
                num_padding_slots_per_request=self.extra_slots_per_request,
                shift_input_ids=self.pass_hidden_states_to_model,
                BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
            )
            if self.pass_hidden_states_to_model:
                assert self.parallel_drafting_hidden_state_tensor is not None
                self.hidden_states[out_hidden_state_mapping] = target_hidden_states
                # Use torch.where to avoid DtoH sync from boolean indexing
                mask = self.is_masked_token_mask[:total_num_output_tokens]
                torch.where(
                    mask.unsqueeze(1),
                    self.parallel_drafting_hidden_state_tensor,
                    self.hidden_states[:total_num_output_tokens],
                    out=self.hidden_states[:total_num_output_tokens],
                )

            # 2.
            # Recompute the slot mapping based on the new positions and
            # rejection mask.
            builder = (
                self._get_attention_metadata_builder()
                if self.attn_metadata_builder is None
                else self.attn_metadata_builder
            )
            new_slot_mapping = compute_new_slot_mapping(
                cad=cad,
                new_positions=self.positions[:total_num_output_tokens],
                is_rejected_token_mask=self.is_rejected_token_mask[
                    :total_num_output_tokens
                ],
                block_size=builder.kv_cache_spec.block_size,
                num_new_tokens=self.net_num_new_slots_per_request,
                max_model_len=self.max_model_len,
            )

            # 3. Update the common attention metadata with the new (meta)data
            new_cad = extend_all_queries_by_N(
                cad,
                N=self.net_num_new_slots_per_request,
                arange=self.arange,
                new_slot_mapping=new_slot_mapping,
            )

            return total_num_output_tokens, token_indices_to_sample, new_cad
833
834
835
836

    def model_returns_tuple(self) -> bool:
        """
        Tell whether the draft model's forward call yields a pair of tensors.

        Methods "mtp" and "draft_model" return a single hidden-states tensor;
        every other method returns two values that callers unpack.
        """
        single_output_methods = ("mtp", "draft_model")
        is_single_output = self.method in single_output_methods
        return not is_single_output

837
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.

        Returns an int32 tensor with one entry per row of `sampled_token_ids`,
        placed on the same device as `self.input_ids`.
        """
        req_ids = gpu_input_batch.req_ids
        # Kept separate from the returned tensor so the list annotation stays
        # accurate for the whole function.
        next_token_id_list: list[int] = []
        for req_index, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[req_index]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_id_list.append(next_token_id)
        return torch.tensor(
            next_token_id_list, dtype=torch.int32, device=self.input_ids.device
        )
869

870
871
872
873
874
875
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead. This is denoted
        the "backup" token id. It also counts rejected tokens via `sampled_token_ids`.

        Returns `(next_token_ids, valid_sampled_tokens_count)`, both int32
        tensors with one entry per row of `sampled_token_ids`, allocated on
        its device and filled by `eagle_prepare_next_token_padded_kernel`.
        """
        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        # One bulk conversion instead of a `.item()` call per request.
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [requests[req_ids[i]].get_token_id(seq_lens[i]) for i in range(num_reqs)],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
        valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

        # Kernel grid: one program per request (row)
        grid = (batch_size,)

        # Find the next power of 2 for block sizes
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count

928
929
930
931
932
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.

        Returns `(spec_common_attn_metadata, token_indices_to_sample,
        num_rejected_tokens_gpu)`. The last two are int32 tensors with one
        entry per request, filled by `eagle_prepare_inputs_padded_kernel`.
        """
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device

        token_indices_to_sample = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )
        num_rejected_tokens_gpu = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )

        # One kernel program per request.
        grid = (num_reqs,)
        eagle_prepare_inputs_padded_kernel[grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )

        # Query lengths are left as they are: rejected tokens stay in the
        # batch as padding.
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )
988

989
990
991
992
993
994
995
996
997
998
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> list[torch.Tensor]:
        """
        Draft a tree of speculative tokens, one tree level per iteration.

        The root level is sampled from `logits`. Each later level re-runs the
        draft model on all tokens drafted so far, using attention metadata
        rebuilt by the tree attention builder, and samples
        `self.child_drafts_per_level[level]` children per draft
        (argmax for a single child, top-k otherwise).

        `slot_mappings` is accepted but not read by this method.

        Returns a list with one `[batch_size, num_drafts_at_level]` tensor of
        draft token ids per tree level.
        """
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=level + 1
            )

            # Apply new attention metadata to all layers. Every layer shares
            # the same object, so the in-place updates below reach all of them.
            per_layer_attn_metadata = {
                layer_name: attn_metadata for layer_name in self.attn_layer_names
            }

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_size = tree_attn_metadata_builder.kv_cache_spec.block_size
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = block_ids * block_size + query_positions % block_size
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            self.input_ids[:num_tokens] = tree_input_ids.view(-1)
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens
            )
            # The dispatcher's token count may differ from num_tokens; the
            # model runs on num_input_tokens and outputs are sliced back.
            num_input_tokens = batch_desc.num_tokens
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    num_input_tokens, attn_metadata.slot_mapping
                ),
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

1162
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # A request with n > 0 drafts had n + 1 candidate tokens; whatever was
        # not sampled counts as rejected. Requests without drafts reject none.
        num_rejected_tokens = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        # The NumPy array shares memory with the tensor, so the cumsum below
        # fills new_query_start_loc_cpu in place.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Converted to a plain int so the metadata does not carry a NumPy scalar.
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            _seq_lens_cpu=new_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE(review): this is the max of the pre-rejection query
            # lengths, i.e. an upper bound on the new ones - confirm that
            # this is intended.
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
1263

1264
    def get_model_name(self, model: nn.Module) -> str:
        """
        Return the class name of `model`.

        If `model` has a `module` attribute (a multi-GPU wrapper), the class
        name of that wrapped object is returned instead.
        """
        if hasattr(model, "module"):  # multi-GPU
            model = model.module
        return model.__class__.__name__

1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
    def _get_model(self) -> nn.Module:
        """
        Default method to call get_model(). Can be overridden by subclasses which
        need to customize model loading.

        The model is loaded from the speculative config's draft model and
        draft load configs, inside the "eagle_head" model tag.
        """
        # Imported at call time rather than at module import.
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            model = get_model(
                vllm_config=self.vllm_config,
                model_config=self.speculative_config.draft_model_config,
                load_config=self.speculative_config.draft_load_config,
            )
        return model

1284
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and connect it to ``target_model``.

        Steps, in order:

        1. Record the attention / indexer layer names present in the config
           before the draft model is created, load the draft model, and treat
           every newly appearing name as one of the draft model's layers.
        2. Create a metadata builder for the draft indexer layers, if any.
        3. If multimodal inputs are enabled, probe whether the draft model
           accepts them and fall back to text-only mode when it does not.
        4. For multimodal targets, copy the image token id onto the draft
           model's config and take the target's language model.
        5. Share embeddings / LM head where applicable and fill the
           parallel-drafting hidden-state tensor.

        Args:
            target_model: The target model whose layers, config and weights
                are consulted.
        """
        # Layer names that exist before the draft model is created.
        attn_names_before = set(
            get_layers_from_vllm_config(
                self.vllm_config,
                AttentionLayerBase,  # type: ignore[type-abstract]
            )
        )
        # FIXME: support hybrid kv for draft model
        indexer_names_before = set(
            get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache)
        )

        self.model = self._get_model()

        # Anything that appeared after loading belongs to the draft model.
        attn_layers_after = get_layers_from_vllm_config(
            self.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        new_attn_names = attn_layers_after.keys() - attn_names_before
        new_indexer_names = indexer_layers.keys() - indexer_names_before
        self.attn_layer_names = list(new_attn_names - new_indexer_names)
        self.indexer_layer_names = list(new_indexer_names)

        if not self.indexer_layer_names:
            self.draft_indexer_metadata_builder = None
        else:
            # The first indexer layer supplies both the builder class and the
            # KV cache spec used for all draft indexer layers.
            lead_indexer = indexer_layers[self.indexer_layer_names[0]]
            builder_cls = lead_indexer.get_attn_backend().get_builder_cls()
            self.draft_indexer_metadata_builder = builder_cls(
                lead_indexer.get_kv_cache_spec(self.vllm_config),
                self.indexer_layer_names,
                self.vllm_config,
                self.device,
            )

        if self.supports_mm_inputs:
            # A multimodal target may still be paired with a text-only draft
            # model, so probe the draft model with a dummy token.
            try:
                probe_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(probe_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if not supports_multimodal(target_model):
            target_language_model = target_model
        else:
            # handle multimodality
            assert hasattr(target_model, "config")
            target_name = self.get_model_name(target_model)
            # The image token id lives under a different config attribute
            # depending on the target architecture.
            if target_name in (
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
                "Qwen3VLMoeForConditionalGeneration",
                "HunYuanVLForConditionalGeneration",
                "GlmOcrForConditionalGeneration",
                "Qwen3_5ForConditionalGeneration",
                "Qwen3_5MoeForConditionalGeneration",
            ):
                image_token = target_model.config.image_token_id
            elif target_name == "PixtralForConditionalGeneration":
                image_token = target_model.config.vision_config.image_token_id
            else:
                image_token = target_model.config.image_token_index
            self.model.config.image_token_index = image_token

            multimodal_target = cast(SupportsMultiModal, target_model)
            target_language_model = multimodal_target.get_language_model()

        self._maybe_share_embeddings(target_language_model)
        self._maybe_share_lm_head(target_language_model)

        if self.parallel_drafting and self.pass_hidden_states_to_model:
            assert self.parallel_drafting_hidden_state_tensor is not None
            if self.eagle3_use_aux_hidden_state:
                # With aux hidden states the mask is viewed at three times the
                # hidden size and passed through combine_hidden_states().
                mask_hidden = self.model.combine_hidden_states(
                    self.model.mask_hidden.view(3 * self.hidden_size)
                )
            else:
                mask_hidden = self.model.mask_hidden.view(self.hidden_size)
            self.parallel_drafting_hidden_state_tensor.copy_(mask_hidden)

    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        """
        Some draft models may not have their own embedding layers, and some may
        have a duplicate copy of the target model's embedding layers. In these cases,
        we share the target model's embedding layers with the draft model to save
        memory.

        Sharing is only attempted when the pipeline-parallel world size is 1;
        otherwise the draft model keeps its separately loaded embedding.

        Args:
            target_language_model: Target language model. It must expose a
                ``model`` attribute that has either ``embed_tokens`` or
                ``embedding``.

        Raises:
            AttributeError: If the target model lacks ``model``, or its inner
                model has neither ``embed_tokens`` nor ``embedding``.
        """
        if get_pp_group().world_size != 1:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )
            return

        inner_model = getattr(target_language_model, "model", None)
        if inner_model is None:
            raise AttributeError("Target model does not have 'model' attribute")
        if hasattr(inner_model, "embed_tokens"):
            target_embed_tokens = inner_model.embed_tokens
        elif hasattr(inner_model, "embedding"):
            target_embed_tokens = inner_model.embedding
        else:
            raise AttributeError(
                "Target model does not have 'embed_tokens' or 'embedding' attribute"
            )

        share_embeddings = False
        if hasattr(self.model, "has_own_embed_tokens"):
            # EAGLE model
            if not self.model.has_own_embed_tokens:
                share_embeddings = True
                logger.info(
                    "Detected EAGLE model without its own embed_tokens in the"
                    " checkpoint. Sharing target model embedding weights with the"
                    " draft model."
                )
            elif (
                isinstance(target_embed_tokens.weight, torch.Tensor)
                and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                # torch.equal() is False for differently sized tensors anyway;
                # checking the shape first skips two pointless device-to-host
                # copies in that case.
                and target_embed_tokens.weight.shape
                == self.model.model.embed_tokens.weight.shape
                # The comparison runs on CPU copies to avoid extra GPU memory
                # usage (e.g. CI environments with limited GPU memory).
                and torch.equal(
                    target_embed_tokens.weight.cpu(),
                    self.model.model.embed_tokens.weight.cpu(),
                )
            ):
                share_embeddings = True
                logger.info(
                    "Detected EAGLE model with embed_tokens identical to the target"
                    " model. Sharing target model embedding weights with the draft"
                    " model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct embed_tokens weights. "
                    "Keeping separate embedding weights from the target model."
                )
        else:
            # MTP model
            share_embeddings = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model embedding weights with the draft model."
            )

        if share_embeddings:
            # Drop the draft model's own module first, then point it at the
            # target's embedding module.
            if hasattr(self.model.model, "embed_tokens"):
                del self.model.model.embed_tokens
            self.model.model.embed_tokens = target_embed_tokens
1450

1451
1452
1453
1454
1455
1456
    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        """
        Some draft models may not have their own LM head, and some may have a
        duplicate copy of the target model's LM head. In these cases, we share
        the target model's LM head with the draft model to save memory.

        After the sharing decision, this method also validates the
        ``use_local_argmax_reduction`` setting against the draft model.

        Args:
            target_language_model: Target language model; its ``lm_head``
                attribute (if present) is the module that gets shared.

        Raises:
            ValueError: If ``use_local_argmax_reduction`` is enabled but the
                draft model does not implement ``get_top_tokens()``.
        """
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                # torch.equal() is False for differently sized tensors anyway;
                # checking the shape first skips two pointless device-to-host
                # copies in that case.
                and target_language_model.lm_head.weight.shape
                == self.model.lm_head.weight.shape
                # The comparison runs on CPU copies to avoid extra GPU memory
                # usage (e.g. CI environments with limited GPU memory).
                and torch.equal(
                    target_language_model.lm_head.weight.cpu(),
                    self.model.lm_head.weight.cpu(),
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head

            # MTP models call compute_logits via shared_head.head (a
            # ParallelLMHead inside each MTP layer), not self.model.lm_head.
            # If the checkpoint omits a copy of the lm_head weights at the
            # MTP layer path, shared_head.head stays uninitialised and
            # produces NaN logits. Always share it explicitly.
            inner = getattr(self.model, "model", None)
            # Test against None explicitly: container modules define __len__,
            # so plain truthiness is not a reliable "is present" check.
            layers = getattr(inner, "layers", None) if inner is not None else None
            if layers is not None:
                items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
                for layer in items:
                    sh = getattr(layer, "shared_head", None)
                    if sh is not None and hasattr(sh, "head"):
                        del sh.head
                        sh.head = target_language_model.lm_head
                        logger.info(
                            "Shared target model lm_head with MTP shared_head.head."
                        )

        if self.use_local_argmax_reduction:
            if not hasattr(self.model, "get_top_tokens"):
                raise ValueError(
                    "use_local_argmax_reduction is enabled but draft model "
                    f"{self.model.__class__.__name__} does not implement "
                    "get_top_tokens()."
                )
            # Warn if draft model has vocab remapping, which forces fallback
            # to the full-logits path (negating the optimization).
            if (
                hasattr(self.model, "draft_id_to_target_id")
                and self.model.draft_id_to_target_id is not None
            ):
                logger.warning(
                    "use_local_argmax_reduction is enabled but draft model "
                    "uses draft_id_to_target_id vocab remapping. The "
                    "optimization will be bypassed (falling back to full "
                    "logits gather + argmax)."
                )
            else:
                logger.info(
                    "Using local argmax reduction for draft token generation "
                    "(communication: O(2*tp_size) vs O(vocab_size))."
                )

1543
1544
1545
1546
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run the draft model on slices of the proposer's own input buffers.

        Args:
            num_tokens: Unpadded token count used to pick the execution mode
                and the padded input size.
            use_cudagraphs: Forwarded to
                ``_determine_batch_execution_and_padding``.
            is_graph_capturing: When true, only a single forward pass is run;
                otherwise ``self.num_speculative_tokens`` passes are run.
            slot_mappings: Optional slot mappings keyed by layer name. If it
                contains the first draft attention layer, the proposer's own
                slot mapping is used instead; otherwise it is passed through
                (an empty dict when it is ``None`` or empty).
        """
        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        num_passes = 1 if is_graph_capturing else self.num_speculative_tokens
        for pass_idx in range(num_passes):
            # Execution mode and padding are only recomputed for the first two
            # passes; later passes reuse the values from pass 1.
            if pass_idx <= 1:
                (
                    cudagraph_runtime_mode,
                    num_input_tokens,
                    num_tokens_across_dp,
                ) = self._determine_batch_execution_and_padding(
                    num_tokens, use_cudagraphs=use_cudagraphs
                )

            # Make sure to use EAGLE's own buffer during cudagraph capture.
            use_own_slot_buffer = (
                self.attn_layer_names
                and slot_mappings is not None
                and self.attn_layer_names[0] in slot_mappings
            )
            if use_own_slot_buffer:
                slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
            else:
                slot_mapping_dict = slot_mappings or {}

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                # Exactly one of input_ids / inputs_embeds is supplied.
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                model_kwargs = {
                    "input_ids": input_ids,
                    "positions": self._get_positions(num_input_tokens),
                    "inputs_embeds": inputs_embeds,
                }
                if self.pass_hidden_states_to_model:
                    model_kwargs["hidden_states"] = self.hidden_states[
                        :num_input_tokens
                    ]
                self.model(**model_kwargs)
1596

1597
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Look up the attention metadata builder for the EAGLE layers.

        The first name in ``self.attn_layer_names`` is searched for in the
        runner's attention groups, and the metadata builder of the group that
        contains it is returned.

        Returns:
            The metadata builder for the EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
        """
        target_layer = self.attn_layer_names[0]

        builder = None
        for kv_cache_group in self.runner.attn_groups:
            # First attention group in this KV cache group holding the layer.
            owner = next(
                (
                    attn_group
                    for attn_group in kv_cache_group
                    if target_layer in attn_group.layer_names
                ),
                None,
            )
            if owner is None:
                continue
            builder = owner.get_metadata_builder()
            if builder is not None:
                break

        assert builder is not None, (
            "Failed to find attention metadata builder for EAGLE layers."
        )
        return builder

1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

1638
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Check that every drafting layer sits in one and the same KVCacheGroup.

        This assumption lets all drafting layers use the same
        AttentionMetadata. It may be relaxed to multiple AttentionMetadata
        in the future.

        Raises:
            AssertionError: If the layers in ``self.attn_layer_names`` do not
                map to exactly one KV cache group.
        """
        # Map each layer name to the index of the KV cache group holding it.
        group_of_layer: dict[str, int] = {
            layer_name: group_idx
            for group_idx, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups
            )
            for layer_name in kv_cache_group.layer_names
        }
        assert (
            len({group_of_layer[name] for name in self.attn_layer_names}) == 1
        ), "All drafting layers should belong to the same kv cache group"
1660

1661
    def _determine_batch_execution_and_padding(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        """Pick the cudagraph mode and padded token count for a batch.

        Args:
            num_tokens: Unpadded number of tokens in the batch.
            use_cudagraphs: When false, dispatch is restricted to
                ``CUDAGraphMode.NONE``.

        Returns:
            ``(cudagraph_mode, num_tokens_padded, num_tokens_across_dp)``.
            The last element is ``None`` when the data-parallel size is not
            greater than 1 or when the coordination returned no token counts.
        """
        allowed_modes = None if use_cudagraphs else {CUDAGraphMode.NONE}
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens,
            valid_modes=allowed_modes,
        )
        num_tokens_padded = batch_desc.num_tokens

        parallel_config = self.vllm_config.parallel_config
        if parallel_config.data_parallel_size <= 1:
            return cudagraph_mode, num_tokens_padded, None

        # Extra coordination when running data-parallel since we need to
        # coordinate across ranks
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
            coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=parallel_config,
                allow_microbatching=False,
                num_tokens_padded=num_tokens_padded,
                cudagraph_mode=cudagraph_mode.value,
            )
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        if num_tokens_across_dp is None:
            return cudagraph_mode, num_tokens_padded, None

        # Extract DP-synced values
        rank = self.dp_rank
        num_tokens_padded = int(num_tokens_across_dp[rank].item())
        # Re-dispatch with DP padding so we have the correct
        # batch_descriptor
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens_padded,
            valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
        )
        # Assert to make sure the agreed upon token count is correct
        # otherwise num_tokens_across_dp will no-longer be valid
        assert batch_desc.num_tokens == num_tokens_padded
        num_tokens_across_dp[rank] = num_tokens_padded

        return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
Rémi Delacourt's avatar
Rémi Delacourt committed
1704

1705

1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
class EagleProposer(SpecDecodeBaseProposer):
    """Proposer that always passes hidden states to the draft model.

    The constructor forwards its arguments to ``SpecDecodeBaseProposer`` and
    fixes ``pass_hidden_states_to_model=True``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        # Hidden-state passing is the only setting this subclass pins; all
        # other arguments go to the base class unchanged.
        super().__init__(
            vllm_config,
            device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )


1721
1722
1723
1724
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample draft tokens from ``logits`` and return them with the probs.

    Only the temperature from ``sampling_metadata`` is applied. When the batch
    is not all-greedy, ``logits`` is divided by the temperature in place.

    Args:
        logits: Draft logits; softmax is taken over the last dimension.
        sampling_metadata: Supplies ``all_greedy``, ``all_random`` and
            ``temperature``.

    Returns:
        ``(next_token_ids, probs)``. For an all-greedy batch ``probs`` is the
        unmodified ``logits`` tensor; otherwise it is the float32 softmax of
        the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling,
        # so the logits are handed back in place of probabilities.
        return logits.argmax(dim=-1), logits

    assert sampling_metadata.temperature is not None

    # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
    # consistent with sampler.py's _SAMPLING_EPS threshold
    temperature = sampling_metadata.temperature
    may_have_greedy = not sampling_metadata.all_random
    if may_have_greedy:
        # Greedy rows get temperature 1.0 so the division below is safe.
        greedy_mask = temperature < _SAMPLING_EPS
        temperature = torch.where(greedy_mask, 1.0, temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    noise = torch.empty_like(probs).exponential_()
    # NOTE(woosuk): The division must stay out-of-place because the
    # draft_probs will be used later for rejection sampling.
    next_token_ids = (probs / noise).argmax(dim=-1).view(-1)
    if may_have_greedy:
        # Greedy rows take the plain argmax instead of the sampled token.
        next_token_ids = torch.where(
            greedy_mask, probs.argmax(dim=-1), next_token_ids
        )
    return next_token_ids, probs