eagle.py 62.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
12
13
14
15
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
16
from vllm.distributed.parallel_state import get_pp_group
17
from vllm.forward_context import set_forward_context
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
20
from vllm.model_executor.model_loader import get_model
21
from vllm.model_executor.models import supports_multimodal
22
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
23
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
24
from vllm.multimodal import MULTIMODAL_REGISTRY
25
from vllm.platforms import current_platform
26
from vllm.triton_utils import triton
27
from vllm.utils.platform_utils import is_pin_memory_available
28
29
30
31
from vllm.v1.attention.backend import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
32
from vllm.v1.attention.backends.registry import AttentionBackendEnum
33
34
35
36
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
37
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
38
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
39
from vllm.v1.kv_cache_interface import KVCacheConfig
40
from vllm.v1.sample.metadata import SamplingMetadata
41
from vllm.v1.sample.sampler import _SAMPLING_EPS
42
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
43
44
45
46
from vllm.v1.spec_decode.utils import (
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
)
47
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
48
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
49
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
50

51
52
logger = init_logger(__name__)

53
54
PADDING_SLOT_ID = -1

55

56
class SpecDecodeBaseProposer:
57
58
59
60
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        pass_hidden_states_to_model: bool,
        runner=None,
    ):
        """Initialize drafter state and allocate the persistent input buffers.

        Args:
            vllm_config: Engine configuration. ``vllm_config.speculative_config``
                must be set (asserted below).
            device: Device on which the persistent buffers are allocated.
            pass_hidden_states_to_model: If True, ``propose`` copies the target
                hidden states into ``self.hidden_states`` and passes them to the
                draft model as the ``hidden_states`` keyword argument.
            runner: The owning model runner; ``propose`` asserts it is set.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None

        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method
        self.pass_hidden_states_to_model = pass_hidden_states_to_model

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

        # The drafter can get longer sequences than the target model.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
        )
        self.token_arange_np = np.arange(self.max_num_tokens)

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Builders and layer names are filled in later (not in this method);
        # propose() falls back to _get_attention_metadata_builder() if unset.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None

        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        self.compilation_config = self.vllm_config.compilation_config

        # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
        # Keys are initialized later via initialize_cudagraph_keys() called from
        # gpu_model_runner._check_and_update_cudagraph_mode after
        # adjust_cudagraph_sizes_for_spec_decode is called.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )

        # Use draft model's M-RoPE setting, not target model's
        # Draft models may be text-only even if target is multimodal
        self.uses_mrope = self.draft_model_config.uses_mrope

        self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
        self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
        # Exactly one of mrope_positions / xdrope_positions / positions is
        # allocated; _get_positions/_set_positions pick the matching buffer.
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions = torch.zeros(
                (self.uses_xdrope_dim, self.max_num_tokens + 1),
                dtype=torch.int64,
                device=device,
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Backing storage for _get_slot_mapping().
        self._slot_mapping_buffer = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

        # Determine allowed attention backends once during initialization.
        # None means "no restriction" (non-ROCm platforms).
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata

            rocm_types = [
                TritonAttentionMetadata,
                RocmAttentionMetadata,
            ]
            # ROCM_AITER_FA is an optional backend
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)

            # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
            from vllm.model_executor.layers.attention.mla_attention import (
                MLACommonMetadata,
            )

            rocm_types.append(MLACommonMetadata)

            # FlexAttention backend support
            from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata

            rocm_types.append(FlexAttentionMetadata)

            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # NOTE(review): assumes tree_choices is ordered by depth so the last
        # node is the deepest one — confirm against the config producer.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            # Integer ratio of nodes between consecutive levels (children per
            # parent, assuming a uniform branching factor per level).
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
        ).repeat(max_batch_size, 1)
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
232
233
        if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            return self.xdrope_positions[:, :num_tokens]
234
235
236
237
238
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
239
240
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions[:, :num_tokens] = positions
241
        else:
242
243
244
245
246
            # Convert M-RoPE positions if target model uses M-RoPE
            # but draft doesn't, For text inputs, all M-RoPE
            # dimensions are identical
            if self.vllm_config.model_config.uses_mrope:
                positions = positions[0]
247
248
            self.positions[:num_tokens] = positions

249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Return slot_mapping dict for EAGLE layers.

        If slot_mapping is provided, copies it into the buffer first.
        """
        if slot_mapping is not None:
            num_actual = slot_mapping.shape[0]
            self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
            if num_tokens > num_actual:
                self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

        view = self._slot_mapping_buffer[:num_tokens]
        return {name: view for name in self.attn_layer_names + self.indexer_layer_names}

267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize cudagraph dispatcher keys for eagle.

        Eagle only supports PIECEWISE cudagraphs (via mixed_mode): PIECEWISE is
        used when speculative decoding is not forced eager and the runner's
        mixed mode is PIECEWISE or FULL; otherwise cudagraphs are disabled
        (NONE). This should be called after adjust_cudagraph_sizes_for_spec_decode.
        """
        use_piecewise = (
            not self.speculative_config.enforce_eager
            and cudagraph_mode.mixed_mode()
            in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
        )
        eagle_cudagraph_mode = (
            CUDAGraphMode.PIECEWISE if use_piecewise else CUDAGraphMode.NONE
        )
        self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> torch.Tensor:
        """Generate greedy draft tokens for every request in the batch.

        The first pass runs the draft model over all target tokens. Its inputs
        are shifted by one token (see ``set_inputs_first_pass``), and logits
        are taken at ``last_token_indices``. If more than one speculative
        token is configured, drafting continues either via ``propose_tree``
        (tree attention) or autoregressively for
        ``num_speculative_tokens - 1`` further single-token steps.

        ``sampling_metadata`` is not read in this method (drafts are argmax).
        ``common_attn_metadata`` is mutated for the multi-step loop, and its
        GPU ``seq_lens`` tensor is updated in place.

        Returns:
            Draft token ids of shape ``[batch_size, num_speculative_tokens]``,
            or ``[batch_size, num_tree_tokens]`` when tree attention is used.

        Raises:
            ValueError: If more than one draft token is requested and the
                attention metadata type is not in ``allowed_attn_types``.
        """
        batch_size = common_attn_metadata.batch_size()

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size

        num_tokens, last_token_indices, common_attn_metadata = (
            self.set_inputs_first_pass(
                target_token_ids=target_token_ids,
                next_token_ids=next_token_ids,
                target_positions=target_positions,
                last_token_indices=last_token_indices,
                cad=common_attn_metadata,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
            )
        )

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
        )

        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens_dp_padded
        )
        # The dispatcher may pad the token count up to a captured size.
        num_input_tokens = batch_desc.num_tokens
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        if self.pass_hidden_states_to_model:
            # target_hidden_states and self.hidden_states can have different
            # hidden dims. E.g. large target model and small draft model.
            self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        model_kwargs = {
            "input_ids": input_ids,
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": inputs_embeds,
        }
        if self.pass_hidden_states_to_model:
            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=self._get_slot_mapping(
                num_input_tokens, common_attn_metadata.slot_mapping
            ),
        ):
            ret_hidden_states = self.model(**model_kwargs)
            if not self.model_returns_tuple():
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states

        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        # Advanced indexing returns copies, so the in-place `positions += 1`
        # below does not touch the persistent position buffers.
        # NOTE(review): when the draft model uses XD-RoPE, __init__ allocates
        # only `xdrope_positions` (no `self.positions`), so this branch would
        # raise AttributeError — confirm that combination never reaches here.
        if self.uses_mrope:
            positions = self.mrope_positions[:, last_token_indices]
        else:
            positions = self.positions[last_token_indices]
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
                slot_mappings=slot_mappings,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
        )

        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            batch_size_dp_padded
        )
        input_batch_size = batch_desc.num_tokens
        if batch_size_across_dp is not None:
            batch_size_across_dp[self.dp_rank] = input_batch_size

        # From here on, every request contributes exactly one query token.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Increment the maximum sequence length. We increment max_seq_len
            # unconditionally even though some seq_lens may have been capped above,
            # as max_seq_len serves as an upper bound for sequence lengths.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )

            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed when common_attn_metadata.seq_lens_cpu is deprecated.
            # For data integrity when async scheduling, these updates must be
            # out of place: the CPU tensors may alias buffers owned by the main
            # model's `prepare_input` (possibly pinned buffers that still feed a
            # pending non-blocking copy), which must not be mutated here.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu = (
                    common_attn_metadata._seq_lens_cpu + 1
                )
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu = (
                    common_attn_metadata._num_computed_tokens_cpu + 1
                )

            # Compute the slot mapping.
            block_size = attn_metadata_builder.kv_cache_spec.block_size
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // block_size
            else:
                block_numbers = clamped_positions // block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions[0] % block_size
                )
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions % block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            model_kwargs = {
                "input_ids": input_ids,
                "positions": self._get_positions(input_batch_size),
                "inputs_embeds": inputs_embeds,
            }
            if self.pass_hidden_states_to_model:
                model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]

            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    input_batch_size, common_attn_metadata.slot_mapping
                ),
            ):
                ret_hidden_states = self.model(**model_kwargs)
                if not self.model_returns_tuple():
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states

            # Drop cudagraph padding rows before sampling.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Fill ``input_ids`` and the position buffer for the first draft pass.

        Args:
            target_token_ids: Token ids verified by the target model.
            next_token_ids: One next token per request, written at each
                request's last token slot.
            target_positions: Positions matching ``target_token_ids``.
            last_token_indices: Index of each request's last token; when None
                it is derived from ``cad.query_start_loc``.
            cad: Common attention metadata, returned unchanged here.
            num_rejected_tokens_gpu: Not used by this implementation (kept so
                the signature matches the caller in ``propose``).

        Returns:
            ``(num_tokens, last_token_indices, cad)``.
        """
        if last_token_indices is None:
            # query_start_loc[i + 1] - 1 is the last token of request i.
            last_token_indices = cad.query_start_loc[1:] - 1

        num_tokens = target_token_ids.shape[0]
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # copy inputs to buffer for cudagraph
        # Target uses XD-RoPE but the draft does not: keep only the first row.
        if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
            target_positions = target_positions[0]
        self._set_positions(num_tokens, target_positions)

        return num_tokens, last_token_indices, cad

    def model_returns_tuple(self) -> bool:
        return self.method not in ("mtp", "draft_model")

647
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.

        Returns:
            An int32 tensor with one next token id per entry of
            ``sampled_token_ids``, on the same device as ``self.input_ids``.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_id_list: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_id_list.append(next_token_id)
        return torch.tensor(
            next_token_id_list, dtype=torch.int32, device=self.input_ids.device
        )

    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute next token ids and valid-sample counts per request on GPU.

        For each request, a "backup" token is precomputed on CPU via
        ``request.get_token_id(seq_len)`` and copied to the GPU. A Triton
        kernel then combines ``sampled_token_ids``, ``discard_request_mask``
        and the backup tokens. It writes one next token id and one
        valid-sampled-token count per row; discarded requests take their
        backup token.

        Returns:
            ``(next_token_ids, valid_sampled_tokens_count)``. Both are int32
            tensors of length ``sampled_token_ids.shape[0]`` on
            ``sampled_token_ids.device``.
        """
        num_reqs = gpu_input_batch.num_reqs

        # CPU-side fallback token for each live request.
        backup_ids = [
            requests[gpu_input_batch.req_ids[row]].get_token_id(
                common_attn_metadata.seq_lens_cpu[row].item()
            )
            for row in range(num_reqs)
        ]
        backup_buf = self.backup_next_token_ids
        backup_buf.np[:num_reqs] = np.array(backup_ids, dtype=np.int32)
        backup_buf.copy_to_gpu(num_reqs)
        backup_tokens_gpu = backup_buf.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
        valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

        # Launch one program per request row. The token block covers the row
        # width rounded up to the next power of two.
        token_block = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[(batch_size,)](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=token_block,
        )

        return next_token_ids, valid_sampled_tokens_count

738
739
740
741
742
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """Build speculator attention metadata without dropping rejected tokens.

        Every token, including rejected ones, stays in the speculator input as
        padding. A GPU kernel fills ``token_indices_to_sample`` (the positions
        to read outputs from later) and the per-request rejected-token counts.
        The returned metadata reuses the original query layout.

        This path must not introduce blocking CPU operations.

        Returns:
            ``(spec_common_attn_metadata, token_indices_to_sample,
            num_rejected_tokens_gpu)``. The last two are int32 tensors of
            length ``num_reqs``.
        """
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device

        def _per_request_int32() -> torch.Tensor:
            return torch.empty((num_reqs,), dtype=torch.int32, device=device)

        token_indices_to_sample = _per_request_int32()
        num_rejected_tokens_gpu = _per_request_int32()

        # One kernel program per request fills both output tensors.
        eagle_prepare_inputs_padded_kernel[(num_reqs,)](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )

        qsl_cpu = common_attn_metadata.query_start_loc_cpu
        query_lens_cpu = qsl_cpu[1:] - qsl_cpu[:-1]
        total_num_tokens = qsl_cpu[-1].item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=qsl_cpu,
            _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=query_lens_cpu.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )
798

799
800
801
802
803
804
805
806
807
808
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> list[torch.Tensor]:
        """Draft a token tree level by level using tree attention.

        The root level is sampled from ``logits``: argmax when a level has one
        child per node, otherwise top-k with k = children per node. Each
        following level is produced by running the draft model over all tree
        tokens drafted so far. Per-level draft counts come from
        ``self.cu_drafts_per_level`` (cumulative) and
        ``self.child_drafts_per_level``.

        The first attention group's metadata builder must be a
        ``TreeAttentionMetadataBuilder``; this is asserted below.

        Args:
            batch_size: Number of requests.
            logits: Root-level logits used to sample the first tree level.
            positions: Position of each request's root token.
            hidden_states: Root hidden states, viewed as
                ``(batch_size, 1, -1)``.
            common_attn_metadata: Base metadata; rebuilt for every level.
            slot_mappings: Not read in this method body.
                NOTE(review): confirm whether callers rely on it.

        Returns:
            A list with one draft-token-id tensor per tree level, each viewed
            as ``(batch_size, -1)``.
        """
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        # Cumulative number of tree nodes up to and including the current level.
        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            # NOTE(review): the overflow check offsets by total_num_drafts (the
            # cumulative node count), not by level + 1; presumably a
            # conservative bound. Confirm.
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            # Every request's query covers all tree tokens drafted so far.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=level + 1
            )

            # Apply new attention metadata to all layers.
            per_layer_attn_metadata = {}
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping: block-table lookup for the block index,
            # plus the offset within the block.
            block_size = tree_attn_metadata_builder.kv_cache_spec.block_size
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = block_ids * block_size + query_positions % block_size
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            # The dispatcher may pad num_tokens up to a captured batch size.
            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens
            )
            num_input_tokens = batch_desc.num_tokens
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    num_input_tokens, attn_metadata.slot_mapping
                ),
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            # Only the trailing level_num_drafts tokens per request belong to
            # the level just drafted.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

972
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """Drop rejected draft tokens from the attention metadata (CPU path).

        Given per-request query lengths ``[q1, q2, q3]``, sequence lengths
        ``[s1, s2, s3]`` and rejected counts ``[n1, n2, n3]``:

        * new ``query_start_loc``: ``[0, q1 - n1, q1 + q2 - n1 - n2, ...]``
        * new ``seq_lens``: ``[s1 - n1, s2 - n2, s3 - n3]``
        * ``token_indices``: the kept tokens' indices in the original layout,
          ``[0 .. q1-n1-1, q1 .. q1+q2-n2-1, q1+q2 .. q1+q2+q3-n3-1]``.

        A request with ``n > 0`` draft tokens rejects
        ``n + 1 - len(sampled)`` of them; a request without drafts rejects none.

        Returns:
            ``(spec_common_attn_metadata, token_indices)``.
        """
        num_rejected_tokens = torch.tensor(
            [
                drafts + 1 - len(sampled_token_ids[req_idx]) if drafts > 0 else 0
                for req_idx, drafts in enumerate(num_draft_tokens)
            ],
            dtype=torch.int32,
        )

        device = common_attn_metadata.query_start_loc.device
        old_qsl_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # Query length per request before, and kept-token count after, rejection.
        old_query_lens = old_qsl_cpu[1:] - old_qsl_cpu[:-1]
        kept_per_req = old_query_lens - num_rejected_tokens
        kept_per_req_np = kept_per_req.numpy()

        # Exclusive prefix sum of kept counts -> compacted query_start_loc.
        new_qsl_cpu = torch.zeros(
            old_qsl_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        new_qsl_np = new_qsl_cpu.numpy()
        np.cumsum(kept_per_req_np, out=new_qsl_np[1:])
        total_num_tokens = new_qsl_np[-1]

        # For kept counts [2, 4, 3], the per-token request starts are
        #   compacted: [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #   original:  [0, 0, q1, q1, q1, q1, q1+q2, q1+q2, q1+q2]
        # so (arange - compacted) is the offset within the request, and adding
        # the original start maps it back into the uncompacted layout.
        compacted_starts = np.repeat(new_qsl_np[:-1], kept_per_req_np)
        within_request = self.token_arange_np[:total_num_tokens] - compacted_starts
        original_starts = np.repeat(old_qsl_cpu[:-1].numpy(), kept_per_req_np)
        token_indices_np = within_request + original_starts
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_qsl_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_qsl_cpu,
            _seq_lens_cpu=new_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=old_query_lens.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
1073

1074
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of ``model``, looking through a ``.module`` wrapper.

        Multi-GPU wrappers expose the wrapped model as ``model.module``. When
        that attribute exists, the inner object's class name is reported.
        """
        unwrapped = model.module if hasattr(model, "module") else model
        return unwrapped.__class__.__name__

1079
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and wire it up to the target model.

        Steps:
          1. Load the draft model under the ``"eagle_head"`` model tag.
          2. Record the attention/indexer layer names the draft added, by
             diffing the layers registered in ``vllm_config`` before and after
             loading. Build an indexer metadata builder if the draft has
             indexer layers.
          3. Probe multimodal support on the draft. Fall back to text-only if
             ``embed_input_ids`` raises ``NotImplementedError``,
             ``AttributeError`` or ``TypeError``.
          4. For multimodal targets, copy the image token id into the draft
             config and use the target's language model for sharing.
          5. Share ``embed_tokens`` (only when the pipeline-parallel world size
             is 1) and ``lm_head`` with the target when the draft lacks its
             own, has identical weights, or is an MTP model.

        Args:
            target_model: The already-loaded target model.

        Raises:
            AttributeError: If the target language model's ``.model`` has
                neither ``embed_tokens`` nor ``embedding`` (PP world size 1).
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        # Snapshot the layers registered before the draft is loaded, so the
        # draft's own layers can be identified afterwards by set difference.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        # Indexer layers are excluded from the regular attention layer list.
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if self.indexer_layer_names:
            # One builder serves all draft indexer layers; it is constructed
            # from the first layer's backend and KV-cache spec.
            first_layer = self.indexer_layer_names[0]
            self.draft_indexer_metadata_builder = (
                indexer_layers[first_layer]
                .get_attn_backend()
                .get_builder_cls()(
                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                    self.indexer_layer_names,
                    self.vllm_config,
                    self.device,
                )
            )
        else:
            self.draft_indexer_metadata_builder = None

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            # Targets differ in where the image token id is stored.
            if self.get_model_name(target_model) in [
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
                "Qwen3VLMoeForConditionalGeneration",
                "HunYuanVLForConditionalGeneration",
                "GlmOcrForConditionalGeneration",
            ]:
                self.model.config.image_token_index = target_model.config.image_token_id
            elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
                self.model.config.image_token_index = (
                    target_model.config.vision_config.image_token_id
                )
            else:
                self.model.config.image_token_index = (
                    target_model.config.image_token_index
                )
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        # (only when the pipeline-parallel world size is 1)
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, "embed_tokens"):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, "embedding"):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            share_embeddings = False
            if hasattr(self.model, "has_own_embed_tokens"):
                # EAGLE model
                if not self.model.has_own_embed_tokens:
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model without its own embed_tokens in the"
                        " checkpoint. Sharing target model embedding weights with the"
                        " draft model."
                    )
                elif (
                    isinstance(target_embed_tokens.weight, torch.Tensor)
                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                    # Compare on CPU to avoid extra GPU memory usage in CI
                    # testing environments with limited GPU memory.
                    and torch.equal(
                        target_embed_tokens.weight.cpu(),
                        self.model.model.embed_tokens.weight.cpu(),
                    )
                ):
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model with embed_tokens identical to the target"
                        " model. Sharing target model embedding weights with the draft"
                        " model."
                    )
                else:
                    logger.info(
                        "Detected EAGLE model with distinct embed_tokens weights. "
                        "Keeping separate embedding weights from the target model."
                    )
            else:
                # MTP model
                share_embeddings = True
                logger.info(
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )

            if share_embeddings:
                # Drop the draft's own copy before pointing at the target's module.
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
        else:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )

        # share lm_head with the target model if needed
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                # Compare on CPU to avoid extra GPU memory usage in CI
                # testing environments with limited GPU memory.
                and torch.equal(
                    target_language_model.lm_head.weight.cpu(),
                    self.model.lm_head.weight.cpu(),
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        # Sharing is skipped silently if the target has no lm_head attribute.
        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head
1262

1263
1264
1265
1266
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run the draft model on placeholder inputs (warm-up / graph capture).

        The model runs ``self.num_speculative_tokens`` times, or once while
        capturing a graph. No attention metadata is passed to the forward
        context.

        Args:
            num_tokens: Unpadded token count for the dummy batch.
            use_cudagraphs: If True, the cudagraph dispatcher picks the runtime
                mode and padded size; otherwise ``CUDAGraphMode.NONE`` is used
                with the DP-padded size.
            is_graph_capturing: Limits the run to a single forward pass.
            slot_mappings: Optional per-layer slot mappings. If it contains the
                first draft attention layer, the proposer's own slot mapping
                buffer is used instead.
        """
        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        for fwd_idx in range(
            self.num_speculative_tokens if not is_graph_capturing else 1
        ):
            # Padding and cudagraph dispatch are computed for passes 0 and 1
            # only; later passes reuse the values from pass 1.
            # NOTE(review): both passes make identical calls here; confirm why
            # pass 1 recomputes them.
            if fwd_idx <= 1:
                num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
                    num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
                )
                if use_cudagraphs:
                    cudagraph_runtime_mode, batch_desc = (
                        self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
                    )
                    num_input_tokens = batch_desc.num_tokens
                else:
                    cudagraph_runtime_mode = CUDAGraphMode.NONE
                    num_input_tokens = num_tokens_dp_padded
                # Record this rank's final (possibly graph-padded) size.
                if num_tokens_across_dp is not None:
                    num_tokens_across_dp[self.dp_rank] = num_input_tokens

            # Make sure to use EAGLE's own buffer during cudagraph capture.
            if (
                self.attn_layer_names
                and slot_mappings is not None
                and self.attn_layer_names[0] in slot_mappings
            ):
                slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
            else:
                slot_mapping_dict = slot_mappings or {}

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                # Multimodal-capable drafts are fed embeddings instead of ids.
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                kwargs = dict(
                    input_ids=input_ids,
                    positions=self._get_positions(num_input_tokens),
                    inputs_embeds=inputs_embeds,
                )
                if self.pass_hidden_states_to_model:
                    kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
                self.model(**kwargs)
1324

1325
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Locate the attention metadata builder used by the EAGLE layers.

        The first drafting layer name (``self.attn_layer_names[0]``) is the
        lookup key. Each KV-cache group in ``self.runner.attn_groups`` is
        scanned for the first attention group listing that layer, and that
        group's metadata builder is returned.

        Returns:
            The attention metadata builder for the EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for the EAGLE layers.
        """
        target_layer = self.attn_layer_names[0]
        builder = None

        for kv_cache_group in self.runner.attn_groups:
            owner = next(
                (
                    candidate
                    for candidate in kv_cache_group
                    if target_layer in candidate.layer_names
                ),
                None,
            )
            if owner is None:
                continue
            builder = owner.get_metadata_builder()
            # A missing builder falls through to the next KV-cache group.
            if builder is not None:
                break

        assert builder is not None, (
            "Failed to find attention metadata builder for EAGLE layers."
        )
        return builder

1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Decide whether an EAGLE3 draft head consumes auxiliary hidden states.

        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use
        auxiliary hidden states and directly use the last layer output, just
        like eagle1. They can signal this by setting "use_aux_hidden_state" to
        False inside the "eagle_config" dict of their hf_config.

        Returns:
            False for any method other than "eagle3". Otherwise, the
            "use_aux_hidden_state" entry of the draft hf_config's
            "eagle_config", defaulting to True when the entry or the whole
            "eagle_config" is absent.
        """
        if self.method != "eagle3":
            return False

        hf_config = self.draft_model_config.hf_config
        eagle_config = getattr(hf_config, "eagle_config", None)
        if eagle_config is None:
            # Without an explicit override, eagle3 heads use aux hidden states.
            return True
        return eagle_config.get("use_aux_hidden_state", True)

1366
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Check that every drafting layer lives in one and the same KVCacheGroup.

        All drafting layers must share a single AttentionMetadata, which is
        only possible when they belong to the same KV cache group. Support for
        multiple AttentionMetadata objects may be added in the future.

        Raises:
            AssertionError: If the drafting layers span more than one group
                (or if there are no drafting layers at all).
            KeyError: If a drafting layer is not listed in any KV cache group.
        """
        # Map each layer name to the index of the KV cache group that owns it.
        group_index_of: dict[str, int] = {
            layer_name: group_index
            for group_index, group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in group.layer_names
        }
        drafting_groups = {group_index_of[name] for name in self.attn_layer_names}
        assert len(drafting_groups) == 1, (
            "All drafting layers should belong to the same kv cache group"
        )
1388

Rémi Delacourt's avatar
Rémi Delacourt committed
1389
1390
1391
1392
1393
1394
    def _pad_batch_across_dp(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor | None]:
        """Coordinate this rank's drafter batch size with the other DP ranks.

        Args:
            num_tokens_unpadded: Number of tokens this rank actually has.
            num_tokens_padded: Token count for this rank before any
                cross-DP padding is applied.

        Returns:
            A tuple ``(num_tokens_dp_padded, num_tokens_across_dp)``. If
            ``coordinate_batch_across_dp`` returns a per-rank token tensor,
            the first element is this rank's (``self.dp_rank``) entry of it
            and the second element is that tensor. Otherwise the first element
            is ``num_tokens_padded`` and the second element is ``None``.

        Raises:
            AssertionError: If the coordinator requests DBO micro-batching,
                which is not implemented for EAGLE.
        """
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            # DP padding is permitted only when the dispatcher's CUDA graph
            # mode is something other than NONE.
            allow_dp_padding=(
                self.cudagraph_dispatcher.cudagraph_mode != CUDAGraphMode.NONE
            ),
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        if num_toks_across_dp is None:
            # No cross-rank information: keep the locally padded count.
            return num_tokens_padded, None
        return int(num_toks_across_dp[self.dp_rank].item()), num_toks_across_dp

1412

1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
class EagleProposer(SpecDecodeBaseProposer):
    """EAGLE draft proposer.

    A thin specialization of ``SpecDecodeBaseProposer`` whose only difference
    is that it always constructs the base class with
    ``pass_hidden_states_to_model=True``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ) -> None:
        # With this flag set, the proposer forwards ``hidden_states`` to the
        # draft model as a keyword argument (see the
        # ``pass_hidden_states_to_model`` check in the base class).
        super().__init__(
            vllm_config,
            device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )


1428
1429
1430
1431
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample draft token ids and return them with the draft probabilities.

    Only the temperature from ``sampling_metadata`` is honoured; other
    sampling parameters are ignored for draft generation. That may lower the
    acceptance rate, but it does not change the distribution of tokens that
    survive rejection sampling.

    Args:
        logits: Draft-model logits with one row per request. The code assumes
            a 2-D ``(num_reqs, vocab_size)`` layout to match the
            ``temperature.view(-1, 1)`` broadcast (TODO confirm). Unless every
            request is greedy, this tensor is divided by the temperature
            **in place**.
        sampling_metadata: Provides ``all_greedy``, ``all_random`` and the
            per-request ``temperature`` tensor.

    Returns:
        ``(next_token_ids, probs)``. When every request is greedy, ``probs``
        is the untouched ``logits`` tensor rather than a probability
        distribution, because rejection sampling does not read it in that case.
    """
    if sampling_metadata.all_greedy:
        # Draft probs are unused for greedy requests, so hand back raw logits.
        return logits.argmax(dim=-1), logits

    temperature = sampling_metadata.temperature
    assert temperature is not None

    mixed_batch = not sampling_metadata.all_random
    if mixed_batch:
        # Treat temperature ~0 as greedy, using the main sampler's
        # _SAMPLING_EPS threshold. Those rows get temperature 1.0 so the
        # division below never divides by zero.
        greedy_mask = temperature < _SAMPLING_EPS
        temperature = torch.where(greedy_mask, 1.0, temperature)

    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # Exponential race: argmax(p_i / q_i) with q_i ~ Exp(1) draws index i
    # with probability p_i, i.e. it samples from ``probs``.
    # TODO(woosuk): Consider seeds.
    noise = torch.empty_like(probs)
    noise.exponential_()
    # Divide out of place: ``probs`` is returned and reused for rejection
    # sampling, so it must not be overwritten.
    next_token_ids = probs.div(noise).argmax(dim=-1).view(-1)

    if mixed_batch:
        next_token_ids = torch.where(
            greedy_mask, probs.argmax(dim=-1), next_token_ids
        )
    return next_token_ids, probs