eagle.py 76.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import ast
4
from importlib.util import find_spec
5
from typing import cast
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
12
13
14
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
15
    replace,
16
)
17
from vllm.distributed.parallel_state import get_pp_group
18
from vllm.forward_context import set_forward_context
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
21
from vllm.model_executor.model_loader import get_model
22
from vllm.model_executor.models import supports_multimodal
23
from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
24
from vllm.model_executor.models.interfaces import SupportsMultiModal
25
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.platforms import current_platform
28
from vllm.triton_utils import triton
29
from vllm.utils.platform_utils import is_pin_memory_available
30
from vllm.v1.attention.backend import CommonAttentionMetadata
31
from vllm.v1.attention.backends.registry import AttentionBackendEnum
32
33
34
35
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
36
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
37
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
38
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
39
from vllm.v1.sample.metadata import SamplingMetadata
40
from vllm.v1.sample.sampler import _SAMPLING_EPS
41
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
42
from vllm.v1.spec_decode.utils import (
43
44
45
    PADDING_SLOT_ID,
    compute_new_slot_mapping,
    copy_and_expand_eagle_inputs_kernel,
46
47
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
48
    eagle_step_update_slot_mapping_and_metadata,
49
    extend_all_queries_by_N,
50
)
51
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
52
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
53
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
54
from vllm.v1.worker.utils import AttentionGroup
55

56
57
# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)

58

59
class SpecDecodeBaseProposer:
60
61
62
63
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        pass_hidden_states_to_model: bool,
        runner=None,
    ):
        """Set up configuration and persistent device buffers for drafting.

        Args:
            vllm_config: Engine configuration; ``speculative_config`` must be
                set (asserted below).
            device: Device on which all persistent tensors are allocated.
            pass_hidden_states_to_model: Whether the draft model receives the
                target hidden states as an input. This affects how many new
                input slots each request needs and whether ``hidden_states``
                is passed to the draft model's forward call.
            runner: Accepted for interface compatibility; not referenced in
                this constructor.

        Raises:
            AssertionError: if ``speculative_config`` or its
                ``speculative_token_tree`` is None.
            NotImplementedError: if extra input slots are needed (draft model
                or parallel drafting) together with an unsupported setting
                (padded drafter batch disabled, multimodal, or M-RoPE).
            ValueError: propagated from ``_init_parallel_drafting_params``
                when parallel drafting is enabled but no parallel-drafting
                token id is configured.
        """
        self.vllm_config = vllm_config
        assert vllm_config.speculative_config is not None
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method
        self.pass_hidden_states_to_model = pass_hidden_states_to_model

        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Unifying eagle, draft model, and parallel drafting support
        self.parallel_drafting: bool = self.speculative_config.parallel_drafting
        self.extra_slots_per_request = (
            1 if not self.parallel_drafting else self.num_speculative_tokens
        )
        self.net_num_new_slots_per_request = self.extra_slots_per_request - (
            1 if self.pass_hidden_states_to_model else 0
        )
        self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0

        self.parallel_drafting_token_id: int = 0
        self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
        if self.parallel_drafting:
            self._init_parallel_drafting_params()

        self.use_local_argmax_reduction: bool = (
            self.speculative_config.use_local_argmax_reduction
        )

        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # Both `token_arange_np` and `self.arange` are sliced with
        # `[: batch_size + 1]` to build query_start_loc, which has one more
        # element than batch_size. batch_size is presumably bounded by
        # max_batch_size (max_num_seqs), so size both aranges to cover
        # max_batch_size + 1 as well as max_num_tokens. Otherwise the slice is
        # silently one element short when batch_size == max_num_tokens.
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.token_arange_np = np.arange(max_num_slots_for_arange)

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        self.draft_attn_groups: list[AttentionGroup] = []
        self.kv_cache_gid: int = -1
        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        self.compilation_config = self.vllm_config.compilation_config

        # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
        # Keys are initialized later via initialize_cudagraph_keys() called from
        # gpu_model_runner._check_and_update_cudagraph_mode after
        # adjust_cudagraph_sizes_for_spec_decode is called.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        # Use draft model's M-RoPE setting, not target model's
        # Draft models may be text-only even if target is multimodal
        self.uses_mrope = self.draft_model_config.uses_mrope
        self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
        self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
        # Exactly one of mrope_positions / xdrope_positions / positions is
        # allocated; accessors must branch the same way.
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions = torch.zeros(
                (self.uses_xdrope_dim, self.max_num_tokens + 1),
                dtype=torch.int64,
                device=device,
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # Will be set when we initialize the attention backend
        self.block_size: int = -1

        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        if self.needs_extra_input_slots:
            self._raise_if_padded_drafter_batch_disabled()
            self._raise_if_multimodal()
            self._raise_if_mrope()

        self.is_rejected_token_mask: torch.Tensor | None = None
        self.is_masked_token_mask: torch.Tensor | None = None
        if self.needs_extra_input_slots:
            # For draft models and parallel drafting, we need to keep track of
            # which tokens are rejected to update the slot mapping with padding slots.
            self.is_rejected_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )
            # For parallel drafting, we also need to keep track of which tokens
            # are parallel-padding tokens used to sample at later positions.
            # We populate this tensor even when using draft models for simplicity.
            self.is_masked_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        self._slot_mapping_buffer = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

        # Determine allowed attention backends once during initialization.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import (
                ROCMAiterMLASparseMetadata,
            )
            from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata

            rocm_types = [
                TritonAttentionMetadata,
                RocmAttentionMetadata,
                ROCMAiterMLASparseMetadata,
            ]
            # ROCM_AITER_FA is an optional backend. find_spec() only checks
            # whether the backend module can be located (importing its parent
            # packages, not the module itself). The metadata class is imported
            # and allowed only when that module is present.
            # NOTE(review): this path does not consult VLLM_ROCM_USE_AITER, so
            # rocm_aiter_fa is imported whenever it is installed. Confirm this
            # is intended.
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)

            # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
            from vllm.model_executor.layers.attention.mla_attention import (
                MLACommonMetadata,
            )

            rocm_types.append(MLACommonMetadata)

            # FlexAttention backend support
            from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata

            rocm_types.append(FlexAttentionMetadata)

            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        assert spec_token_tree is not None
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # The depth is taken from the last node, so the choices are presumably
        # ordered by depth.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
        ).repeat(max_batch_size, 1)

280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
    def _raise_if_padded_drafter_batch_disabled(self):
        if self.speculative_config.disable_padded_drafter_batch:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting only "
                "supports padded drafter batch. Please unset "
                "disable_padded_drafter_batch in the speculative_config."
            )

    def _raise_if_multimodal(self):
        if self.supports_mm_inputs:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting "
                "does not support multimodal models yet"
            )

    def _raise_if_mrope(self):
        if self.draft_model_config.uses_mrope:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting "
                "does not support M-RoPE yet"
            )

    def _init_parallel_drafting_params(self):
        # For parallel drafting, we need the token ID to use for masked slots
        # And for EAGLE + parallel drafting, we need the hidden state tensor to use
        # for those masked slots.

        model_hf_config = self.draft_model_config.hf_config
        if hasattr(model_hf_config, "pard_token"):
            self.parallel_drafting_token_id = model_hf_config.pard_token
        elif hasattr(model_hf_config, "ptd_token_id"):
            self.parallel_drafting_token_id = model_hf_config.ptd_token_id
        else:
            raise ValueError(
                "For parallel drafting, the draft model config must have "
                "`pard_token` or `ptd_token_id` specified in its config.json."
            )

        if self.pass_hidden_states_to_model:
            self.parallel_drafting_hidden_state_tensor = torch.empty(
                self.hidden_size, dtype=self.dtype, device=self.device
            )

323
324
325
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
326
327
        if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            return self.xdrope_positions[:, :num_tokens]
328
329
330
331
332
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
333
334
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions[:, :num_tokens] = positions
335
        else:
336
337
338
339
340
            # Convert M-RoPE positions if target model uses M-RoPE
            # but draft doesn't, For text inputs, all M-RoPE
            # dimensions are identical
            if self.vllm_config.model_config.uses_mrope:
                positions = positions[0]
341
342
            self.positions[:num_tokens] = positions

343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Return slot_mapping dict for EAGLE layers.

        If slot_mapping is provided, copies it into the buffer first.
        """
        if slot_mapping is not None:
            num_actual = slot_mapping.shape[0]
            self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
            if num_tokens > num_actual:
                self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

        view = self._slot_mapping_buffer[:num_tokens]
359
        return {name: view for name in self._draft_attn_layer_names}
360

361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize the eagle cudagraph dispatcher keys.

        Eagle dispatches PIECEWISE cudagraphs only, chosen when eager mode is
        not enforced for speculation and the runner's mixed mode is PIECEWISE
        or FULL. Otherwise cudagraphs are disabled (NONE). Call this after
        adjust_cudagraph_sizes_for_spec_decode.
        """
        use_piecewise = not self.speculative_config.enforce_eager and (
            cudagraph_mode.mixed_mode() in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
        )
        eagle_cudagraph_mode = (
            CUDAGraphMode.PIECEWISE if use_piecewise else CUDAGraphMode.NONE
        )
        self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

378
379
380
381
382
383
    def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Greedy-sample draft tokens from hidden states."""
        if self.use_local_argmax_reduction:
            return self.model.get_top_tokens(hidden_states)
        return self.model.compute_logits(hidden_states).argmax(dim=-1)

384
385
386
387
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> torch.Tensor:
        """Run the draft model to propose speculative tokens.

        The first forward pass uses the inputs staged by
        ``set_inputs_first_pass``. Tokens are returned immediately in two
        cases: when only one draft token is needed, or when parallel drafting
        is enabled. With tree attention metadata, drafting is delegated to
        ``propose_tree``. Otherwise the remaining
        ``num_speculative_tokens - 1`` tokens are produced by one draft
        forward pass per iteration over ``batch_size`` tokens.

        Args:
            target_token_ids: Target-model token ids, ``[num_tokens]``.
            target_positions: ``[num_tokens]``, or ``[3, num_tokens]`` with
                M-RoPE.
            target_hidden_states: ``[num_tokens, hidden_size]``. For
                ``method == "eagle3"`` these are first passed through
                ``self.model.combine_hidden_states``.
            next_token_ids: ``[batch_size]`` tokens forwarded to
                ``set_inputs_first_pass``.
            token_indices_to_sample: Indices into the first-pass output to
                sample from, or ``None`` to let ``set_inputs_first_pass``
                compute them.
            common_attn_metadata: Batch attention metadata. NOTE: it may be
                mutated in place (``set_inputs_first_pass`` can return the
                same object, and ``seq_lens`` is decremented with ``-=``).
            sampling_metadata: Not referenced in this method's body.
            mm_embed_inputs: ``(mm_embeds, is_mm_embed)`` forwarded to
                ``embed_input_ids`` when multimodal inputs are supported.
            num_rejected_tokens_gpu: Subtracted from ``seq_lens`` before the
                sequential drafting steps (presumably per-request counts).
            slot_mappings: Forwarded only to ``propose_tree``.

        Returns:
            Draft token ids reshaped to ``[-1, num_speculative_tokens]`` on the
            early-exit path, ``[batch_size, num_tree_tokens]`` on the tree path,
            and ``[batch_size, num_speculative_tokens]`` otherwise.

        Raises:
            ValueError: if ``allowed_attn_types`` is set and the draft
                attention metadata type is not one of them (sequential path).
        """
        batch_size = common_attn_metadata.batch_size()

        if self.method == "eagle3":
            assert isinstance(
                self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
            )
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size

        num_tokens, token_indices_to_sample, common_attn_metadata = (
            self.set_inputs_first_pass(
                target_token_ids=target_token_ids,
                next_token_ids=next_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                token_indices_to_sample=token_indices_to_sample,
                cad=common_attn_metadata,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
            )
        )

        per_layer_attn_metadata: dict[str, object] = {}
        for attn_group in self.draft_attn_groups:
            attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=0
            )
            for layer_name in attn_group.layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(num_tokens)
        )

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        model_kwargs = {
            "input_ids": input_ids,
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": inputs_embeds,
        }
        if self.pass_hidden_states_to_model:
            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=self._get_slot_mapping(
                num_input_tokens, common_attn_metadata.slot_mapping
            ),
        ):
            ret_hidden_states = self.model(**model_kwargs)
            if not self.model_returns_tuple():
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states

        sample_hidden_states = last_hidden_states[token_indices_to_sample]

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1 or self.parallel_drafting:
            draft_token_ids = self._greedy_sample(sample_hidden_states)
            return draft_token_ids.view(-1, self.num_speculative_tokens)

        # Same condition __init__ uses to allocate `xdrope_positions`
        # (in which case `self.positions` does not exist).
        uses_draft_xdrope = self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0
        if self.uses_mrope:
            positions = self.mrope_positions[:, token_indices_to_sample]
        elif uses_draft_xdrope:
            # Read from the XD-RoPE buffer, since `self.positions` is not
            # allocated here. Row 0 gives 1-D positions, matching what the
            # decode loop below expects for non-M-RoPE and writes back.
            positions = self.xdrope_positions[0, token_indices_to_sample]
        else:
            positions = self.positions[token_indices_to_sample]
        hidden_states = hidden_states[token_indices_to_sample]

        # `attn_metadata` is the metadata built for the last draft attention
        # group in the loop above.
        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention - requires full logits for top-k
            logits = self.model.compute_logits(sample_hidden_states)
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
                slot_mappings=slot_mappings,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = self._greedy_sample(sample_hidden_states)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
            self._determine_batch_execution_and_padding(batch_size)
        )

        # From here on, each forward pass processes exactly one token per
        # request.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

        block_size = self.block_size
        assert block_size > 0, "block_size has not been initialized."
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()

            # Use fused kernel for slot mapping and metadata updates.
            # Write clamped positions directly into the positions buffer to
            # avoid an extra D2D copy for the common (non-mrope) case.
            positions_1d = positions[0] if self.uses_mrope else positions
            if self.uses_mrope:
                out_pos = self.mrope_positions[0, :batch_size]
            elif uses_draft_xdrope:
                out_pos = self.xdrope_positions[0, :batch_size]
            else:
                out_pos = self.positions[:batch_size]
            eagle_step_update_slot_mapping_and_metadata(
                positions_1d=positions_1d,
                block_table_tensor=common_attn_metadata.block_table_tensor,
                seq_lens=common_attn_metadata.seq_lens,
                block_size=block_size,
                max_model_len=self.max_model_len,
                out_clamped_positions=out_pos,
                out_slot_mapping=self._slot_mapping_buffer[:input_batch_size],
                input_batch_size=input_batch_size,
            )
            common_attn_metadata.slot_mapping = self._slot_mapping_buffer[:batch_size]
            if self.uses_mrope:
                # Broadcast the updated row 0 to the other M-RoPE rows.
                self.mrope_positions[1:, :batch_size] = self.mrope_positions[
                    0, :batch_size
                ]
                positions = self.mrope_positions[:, :batch_size]
            elif uses_draft_xdrope:
                self.xdrope_positions[1:, :batch_size] = self.xdrope_positions[
                    0, :batch_size
                ]
                positions = self.xdrope_positions[0, :batch_size]
            else:
                positions = self.positions[:batch_size]
            # Increment the maximum sequence length. We increment max_seq_len
            # unconditionally even though some seq_lens may have been capped above,
            # as max_seq_len serves as an upper bound for sequence lengths.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )

            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed when common_attn_metadata.seq_lens_cpu is deprecated.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu += 1
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu += 1

            # Rebuild attention metadata
            for attn_group in self.draft_attn_groups:
                attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=token_index + 1,
                )
                for layer_name in attn_group.layer_names:
                    per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            model_kwargs = {
                "input_ids": input_ids,
                "positions": self._get_positions(input_batch_size),
                "inputs_embeds": inputs_embeds,
            }
            if self.pass_hidden_states_to_model:
                model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]

            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(input_batch_size),
            ):
                ret_hidden_states = self.model(**model_kwargs)
                if not self.model_returns_tuple():
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states

            hidden_states = hidden_states[:batch_size]
            draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
647

648
649
650
651
652
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """
        Fill the drafting input buffers for the first draft forward pass.

        Two pathways exist:

        * Default (``needs_extra_input_slots`` is False): the target token ids
          are shifted left by one and ``next_token_ids`` are written at
          ``token_indices_to_sample`` (defaulting to the last token of each
          request). Positions and hidden states are copied as-is and ``cad`` is
          returned unchanged. ``num_rejected_tokens_gpu`` is not used here.
        * Extra input slots: a Triton kernel copies/expands the token ids and
          positions into ``self.input_ids`` / ``self.positions`` with
          ``net_num_new_slots_per_request`` additional slots per request, writes
          ``is_rejected_token_mask`` / ``is_masked_token_mask`` and computes
          ``token_indices_to_sample`` itself (the argument value is ignored).
          The slot mapping and attention metadata are then rebuilt for the
          expanded layout. When given, ``num_rejected_tokens_gpu`` is subtracted
          from each request's last input index before the copy.

        Returns:
            ``(num_tokens, token_indices_to_sample, cad)``: the number of buffer
            slots filled, the indices whose draft outputs are sampled, and the
            attention metadata to use for the draft forward pass.
        """
        if not self.needs_extra_input_slots:
            # Default EAGLE pathway: no reshaping of input tensors needed.
            # Simply rotate the input ids and leave the positions unchanged,
            # inserting the next token ids at the last slot in each request.
            if token_indices_to_sample is None:
                token_indices_to_sample = cad.query_start_loc[1:] - 1

            num_tokens = target_token_ids.shape[0]
            # Shift the input ids by one token.
            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
            self.input_ids[: num_tokens - 1] = target_token_ids[1:]
            # Replace the last token with the next token.
            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
            self.input_ids[token_indices_to_sample] = next_token_ids

            # copy inputs to buffer for cudagraph
            if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
                # The draft model does not use xdrope: keep only the first
                # row of the target's multi-row positions.
                target_positions = target_positions[0]
            self._set_positions(num_tokens, target_positions)

            self.hidden_states[:num_tokens] = target_hidden_states

            return num_tokens, token_indices_to_sample, cad
        else:
            assert self.is_rejected_token_mask is not None
            assert self.is_masked_token_mask is not None
            # 1.
            # Call a custom triton kernel to copy input_ids and positions
            # into the correct slots in the preallocated buffers self.input_ids,
            # self.positions.
            batch_size = cad.batch_size()
            # Since we might have to copy a lot of data for prefills, we select the
            # block size based on the max query length and limit to max 256 slots/block.
            max_num_tokens_per_request = (
                cad.max_query_len + self.net_num_new_slots_per_request
            )
            BLOCK_SIZE_TOKENS = min(
                256, triton.next_power_of_2(max_num_tokens_per_request)
            )
            # Ceil-divide so every token of the longest request is covered.
            num_blocks = (
                max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1
            ) // BLOCK_SIZE_TOKENS
            total_num_input_tokens = target_token_ids.shape[0]
            total_num_output_tokens = total_num_input_tokens + (
                self.net_num_new_slots_per_request * batch_size
            )

            # Filled by the kernel (out_new_token_indices_ptr).
            token_indices_to_sample = torch.empty(
                batch_size * self.extra_slots_per_request,
                dtype=torch.int32,
                device=self.device,
            )

            # Destination indices to write target_hidden_states into drafting buffer.
            out_hidden_state_mapping = torch.empty(
                total_num_input_tokens, dtype=torch.int32, device=self.device
            )

            # Kernel grid: one program per (request, token block) pair.
            grid = (batch_size, num_blocks)
            query_start_loc = cad.query_start_loc
            query_end_loc = cad.query_start_loc[1:] - 1
            if num_rejected_tokens_gpu is not None:
                query_end_loc = query_end_loc - num_rejected_tokens_gpu
            copy_and_expand_eagle_inputs_kernel[grid](
                # (Padded) Inputs from the target model
                target_token_ids_ptr=target_token_ids,
                target_positions_ptr=target_positions,
                next_token_ids_ptr=next_token_ids,  # sampled tokens, one per request
                # Outputs to the drafting buffers
                out_input_ids_ptr=self.input_ids,
                out_positions_ptr=self.positions,  # Doesn't support mrope for now
                out_is_rejected_token_mask_ptr=self.is_rejected_token_mask,
                out_is_masked_token_mask_ptr=self.is_masked_token_mask,
                out_new_token_indices_ptr=token_indices_to_sample,
                out_hidden_state_mapping_ptr=out_hidden_state_mapping,
                # Input metadata
                query_start_loc_ptr=query_start_loc,
                query_end_loc_ptr=query_end_loc,
                padding_token_id=0,
                parallel_drafting_token_id=self.parallel_drafting_token_id,
                # Sizing info
                # Note that we can deduce batch_size for free from the grid size
                total_input_tokens=total_num_input_tokens,
                num_padding_slots_per_request=self.extra_slots_per_request,
                shift_input_ids=self.pass_hidden_states_to_model,
                BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
            )
            if self.pass_hidden_states_to_model:
                assert self.parallel_drafting_hidden_state_tensor is not None
                self.hidden_states[out_hidden_state_mapping] = target_hidden_states
                # Use torch.where to avoid DtoH sync from boolean indexing
                mask = self.is_masked_token_mask[:total_num_output_tokens]
                torch.where(
                    mask.unsqueeze(1),
                    self.parallel_drafting_hidden_state_tensor,
                    self.hidden_states[:total_num_output_tokens],
                    out=self.hidden_states[:total_num_output_tokens],
                )

            # 2.
            # Recompute the slot mapping based on the new positions and
            # rejection mask.
            assert self.block_size > 0, "block_size has not been initialized."
            new_slot_mapping = compute_new_slot_mapping(
                cad=cad,
                new_positions=self.positions[:total_num_output_tokens],
                is_rejected_token_mask=self.is_rejected_token_mask[
                    :total_num_output_tokens
                ],
                block_size=self.block_size,
                num_new_tokens=self.net_num_new_slots_per_request,
                max_model_len=self.max_model_len,
            )

            # 3. Update the common attention metadata with the new (meta)data
            new_cad = extend_all_queries_by_N(
                cad,
                N=self.net_num_new_slots_per_request,
                arange=self.arange,
                new_slot_mapping=new_slot_mapping,
            )

            return total_num_output_tokens, token_indices_to_sample, new_cad

    def model_returns_tuple(self) -> bool:
        """Report whether the draft model's forward pass returns a tuple.

        The ``"mtp"`` and ``"draft_model"`` methods return a single hidden-state
        tensor. Every other method returns ``(last_hidden_states, hidden_states)``.
        """
        single_tensor_methods = ("mtp", "draft_model")
        return self.method not in single_tensor_methods

    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.

        Returns:
            An int32 tensor with one next token id per entry of
            ``sampled_token_ids``, on the same device as ``self.input_ids``.
        """
        req_ids = gpu_input_batch.req_ids
        # Kept as a separate name so the list and the returned tensor are not
        # bound to the same (list[int]-annotated) variable.
        next_token_ids_list: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids_list.append(next_token_id)
        return torch.tensor(
            next_token_ids_list, dtype=torch.int32, device=self.input_ids.device
        )

    def prepare_next_token_ids_padded(
        self,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead. This is
        denoted the "backup" token id. It also counts rejected tokens via
        `sampled_token_ids`.

        The backup token of each request is ``get_token_id(i)`` at index
        ``i = num_tokens_no_spec - 1``. It is staged through
        ``self.backup_next_token_ids`` and copied to the GPU before the kernel
        launch.

        Returns:
            ``(next_token_ids, valid_sampled_tokens_count)``: two int32 tensors
            of length ``batch_size`` on the device of ``sampled_token_ids``.
        """
        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [
                requests[req_id].get_token_id(seq_len)
                for req_id, seq_len in zip(req_ids, seq_lens_list)
            ],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
        # Same dtype (int32) and device as next_token_ids.
        valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

        # Kernel grid: one program per request (row)
        grid = (batch_size,)

        # Find the next power of 2 for block sizes
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count

    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.

        Returns:
            ``(spec_common_attn_metadata, token_indices_to_sample,
            num_rejected_tokens_gpu)``. The last two are int32 tensors of length
            ``num_reqs``, filled on-device by the Triton kernel.
        """
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device

        token_indices_to_sample = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )
        num_rejected_tokens_gpu = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )

        grid = (num_reqs,)
        eagle_prepare_inputs_padded_kernel[grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )

        # Query lengths are unchanged here: rejected tokens stay in as padding.
        # These are CPU tensors, so .item() below does not sync with the GPU.
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.max_seq_len,
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )

    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> list[torch.Tensor]:
        """
        Propose draft tokens level by level along the speculative token tree.

        Root-level children are sampled from ``logits``: argmax when the level
        has one child per node, top-k indices otherwise. For each further level:

        * the previous level's drafts are appended to the tree inputs;
        * tree-attention metadata is built via ``build_for_drafting``;
        * the draft model is run;
        * the next level's children are sampled from its logits.

        Requests whose positions would reach ``max_model_len`` get position 0,
        a sequence length of 1 and ``PADDING_SLOT_ID`` slots.

        ``slot_mappings`` is accepted for interface compatibility but is not
        used by this implementation.

        Returns:
            One tensor of sampled draft token ids per tree level, each viewed as
            ``(batch_size, -1)``.
        """
        tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=level + 1
            )

            # Apply new attention metadata to all draft layers.
            per_layer_attn_metadata = {}
            for attn_group in self.draft_attn_groups:
                for layer_name in attn_group.layer_names:
                    per_layer_attn_metadata[layer_name] = attn_metadata

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_size = tree_attn_metadata_builder.kv_cache_spec.block_size
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = block_ids * block_size + query_positions % block_size
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens
            )
            num_input_tokens = batch_desc.num_tokens
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    num_input_tokens, attn_metadata.slot_mapping
                ),
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.

        For a request with ``n > 0`` draft tokens, the number of rejected tokens
        is ``n + 1 - len(sampled_token_ids[i])``. Requests without drafts reject
        nothing.

        Returns:
            ``(spec_common_attn_metadata, token_indices)``: the metadata with the
            rejected tokens removed from each query, and the indices (moved to
            the metadata's device) of the kept tokens in the original flattened
            token layout.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        num_rejected_tokens = torch.tensor(
            [
                n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ],
            dtype=torch.int32,
        )

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Python int, consistent with the `.item()` values used elsewhere.
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            _seq_lens_cpu=new_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE(review): this is the pre-rejection maximum (an upper bound of
            # the new query lengths); confirm whether the tighter
            # new_num_tokens_per_req.max() is intended here.
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices

    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of ``model``.

        If ``model`` exposes the real model as a ``.module`` attribute (the
        wrapper pattern marked "multi-GPU" below), one level of wrapping is
        removed first, so the wrapped model's class name is returned.
        """
        if hasattr(model, "module"):  # multi-GPU
            model = model.module
        return model.__class__.__name__

    def _create_draft_vllm_config(self) -> VllmConfig:
        """Build the VllmConfig used when loading the proposer (draft) model.

        If the speculative config names a ``moe_backend``, this returns a copy of
        ``self.vllm_config`` whose kernel config uses that backend. Otherwise it
        returns ``self.vllm_config`` itself. Subclasses may override this to
        apply additional config changes.
        """
        draft_moe_backend = self.speculative_config.moe_backend
        if draft_moe_backend is None:
            return self.vllm_config

        base_config = self.vllm_config
        draft_kernel_config = replace(
            base_config.kernel_config,
            moe_backend=draft_moe_backend,
        )
        return replace(base_config, kernel_config=draft_kernel_config)

    def _get_model(self) -> nn.Module:
        """
        Default method to call get_model(). Can be overridden by subclasses which
        need to customize model loading.

        The draft model is loaded with the config from
        ``_create_draft_vllm_config()`` plus the speculative config's draft
        model/load configs, under the ``"eagle_head"`` model tag.
        """
        from vllm.compilation.backends import set_model_tag

        draft_vllm_config = self._create_draft_vllm_config()
        with set_model_tag("eagle_head"):
            model = get_model(
                vllm_config=draft_vllm_config,
                model_config=self.speculative_config.draft_model_config,
                load_config=self.speculative_config.draft_load_config,
            )
        return model

    def load_model(self, target_model: nn.Module) -> None:
        """
        Load the draft model and connect it to ``target_model``.

        Steps:
        1. Load the draft model via ``_get_model()``. Record as draft layers the
           attention layers that appear in the config only after loading.
        2. If multimodal inputs are enabled, probe ``embed_input_ids``. Fall back
           to text-only drafting if the probe raises NotImplementedError,
           AttributeError or TypeError.
        3. For a multimodal target, copy the target's image token id into the
           draft model config, and use the target's language model for sharing.
        4. Share embeddings / LM head with the target where applicable.
        5. For parallel drafting with hidden-state inputs, initialize
           ``parallel_drafting_hidden_state_tensor`` from the draft model's
           ``mask_hidden``.
        """
        target_attn_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config,
                AttentionLayerBase,  # type: ignore[type-abstract]
            ).keys()
        )

        self.model = self._get_model()

        # Find draft layers (attention layers added by draft model)
        all_attn_layers = get_layers_from_vllm_config(
            self.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        self._draft_attn_layer_names = (
            set(all_attn_layers.keys()) - target_attn_layer_names
        )

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            assert hasattr(target_model, "config")
            # Resolve the (possibly wrapped) target class name once.
            target_model_name = self.get_model_name(target_model)
            if target_model_name in (
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
                "Qwen3VLMoeForConditionalGeneration",
                "Gemma4ForConditionalGeneration",
                "HunYuanVLForConditionalGeneration",
                "GlmOcrForConditionalGeneration",
                "Qwen3_5ForConditionalGeneration",
                "Qwen3_5MoeForConditionalGeneration",
            ):
                self.model.config.image_token_index = target_model.config.image_token_id
            elif target_model_name == "PixtralForConditionalGeneration":
                self.model.config.image_token_index = (
                    target_model.config.vision_config.image_token_id
                )
            elif target_model_name == "KimiK25ForConditionalGeneration":
                self.model.config.image_token_index = (
                    target_model.config.media_placeholder_token_id
                )
            else:
                self.model.config.image_token_index = (
                    target_model.config.image_token_index
                )
            target_language_model = cast(
                SupportsMultiModal, target_model
            ).get_language_model()
        else:
            target_language_model = target_model

        self._maybe_share_embeddings(target_language_model)
        self._maybe_share_lm_head(target_language_model)

        if self.parallel_drafting and self.pass_hidden_states_to_model:
            assert self.parallel_drafting_hidden_state_tensor is not None
            # With eagle3 aux hidden states, mask_hidden has 3 * hidden_size
            # values and is projected via combine_hidden_states first.
            self.parallel_drafting_hidden_state_tensor.copy_(
                self.model.combine_hidden_states(
                    self.model.mask_hidden.view(3 * self.hidden_size)
                )
                if self.eagle3_use_aux_hidden_state
                else self.model.mask_hidden.view(self.hidden_size)
            )

    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        """
        Share the target model's vocab embedding with the draft model if possible.

        Some draft models may not have their own embedding layers, and some may
        have a duplicate copy of the target model's embedding layers. In these
        cases, we share the target model's embedding layers with the draft model
        to save memory.

        Sharing is only attempted when the pipeline-parallel world size is 1;
        otherwise the draft keeps the embedding it loaded itself. With PP == 1:
          * EAGLE drafts (``self.model`` has ``has_own_embed_tokens``) share when
            the checkpoint has no embed_tokens, or when their embedding weight
            is identical to the target's.
          * Any other draft (treated as MTP) always shares.

        Args:
            target_language_model: Language backbone of the target model; its
                ``model.embed_tokens`` (or ``model.embedding``) is the source.

        Raises:
            AttributeError: If (with PP == 1) the target has no ``model``
                attribute, or that inner model has neither ``embed_tokens`` nor
                ``embedding``.
        """
        if get_pp_group().world_size != 1:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )
            return

        inner_model = getattr(target_language_model, "model", None)
        if inner_model is None:
            raise AttributeError("Target model does not have 'model' attribute")
        if hasattr(inner_model, "embed_tokens"):
            target_embed_tokens = inner_model.embed_tokens
        elif hasattr(inner_model, "embedding"):
            target_embed_tokens = inner_model.embedding
        else:
            raise AttributeError(
                "Target model does not have 'embed_tokens' or 'embedding' attribute"
            )

        share_embeddings = False
        if hasattr(self.model, "has_own_embed_tokens"):
            # EAGLE model
            if not self.model.has_own_embed_tokens:
                share_embeddings = True
                logger.info(
                    "Detected EAGLE model without its own embed_tokens in the"
                    " checkpoint. Sharing target model embedding weights with the"
                    " draft model."
                )
            else:
                # If either module lacks a tensor ``weight``, treat the tables
                # as distinct instead of raising AttributeError.
                target_weight = getattr(target_embed_tokens, "weight", None)
                draft_weight = getattr(
                    getattr(self.model.model, "embed_tokens", None), "weight", None
                )
                if (
                    isinstance(target_weight, torch.Tensor)
                    and isinstance(draft_weight, torch.Tensor)
                    # torch.equal is False for differently-shaped tensors anyway;
                    # checking shapes first skips the host copies in that case.
                    and target_weight.shape == draft_weight.shape
                    # TODO: Offload to CPU for comparison to avoid extra GPU memory
                    # usage in CI testing environments with limited GPU memory
                    and torch.equal(target_weight.cpu(), draft_weight.cpu())
                ):
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model with embed_tokens identical to the"
                        " target model. Sharing target model embedding weights"
                        " with the draft model."
                    )
                else:
                    logger.info(
                        "Detected EAGLE model with distinct embed_tokens weights. "
                        "Keeping separate embedding weights from the target model."
                    )
        else:
            # MTP model
            share_embeddings = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model embedding weights with the draft model."
            )

        if share_embeddings:
            # Drop the draft's own module before rebinding to the target's.
            if hasattr(self.model.model, "embed_tokens"):
                del self.model.model.embed_tokens
            self.model.model.embed_tokens = target_embed_tokens
1391

1392
1393
1394
1395
1396
1397
    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        """
        Share the target model's LM head with the draft model if possible.

        Some draft models may not have their own LM head, and some may have a
        duplicate copy of the target model's LM head. In these cases, we share
        the target model's LM head with the draft model to save memory.

        Decision rules:
          * EAGLE drafts (``self.model`` has ``has_own_lm_head``) share when the
            checkpoint has no lm_head, or when their lm_head weight is identical
            to the target's; otherwise they keep their own head.
          * Any other draft (treated as MTP) always shares.

        Sharing only happens if the target exposes ``lm_head``. When it does,
        every ``shared_head.head`` found on ``self.model.model.layers`` is also
        rebound to the target head.

        Finally, if ``use_local_argmax_reduction`` is enabled, checks that the
        draft model implements ``get_top_tokens()``.

        Args:
            target_language_model: Language backbone of the target model.

        Raises:
            ValueError: If ``use_local_argmax_reduction`` is enabled but the
                draft model does not implement ``get_top_tokens()``.
        """
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                # Missing modules / non-tensor weights mean "not identical"
                # rather than an AttributeError.
                target_weight = getattr(
                    getattr(target_language_model, "lm_head", None), "weight", None
                )
                draft_weight = getattr(
                    getattr(self.model, "lm_head", None), "weight", None
                )
                if (
                    isinstance(target_weight, torch.Tensor)
                    and isinstance(draft_weight, torch.Tensor)
                    # torch.equal is False for differently-shaped tensors anyway;
                    # checking shapes first skips the host copies in that case.
                    and target_weight.shape == draft_weight.shape
                    # TODO: Offload to CPU for comparison to avoid extra GPU memory
                    # usage in CI testing environments with limited GPU memory
                    and torch.equal(target_weight.cpu(), draft_weight.cpu())
                ):
                    share_lm_head = True
                    logger.info(
                        "Detected EAGLE model with lm_head identical to the target "
                        "model. Sharing target model lm_head weights with the draft "
                        "model."
                    )
                else:
                    logger.info(
                        "Detected EAGLE model with distinct lm_head weights. "
                        "Keeping separate lm_head weights from the target model."
                    )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            target_lm_head = target_language_model.lm_head
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_lm_head

            # MTP models call compute_logits via shared_head.head (a
            # ParallelLMHead inside each MTP layer), not self.model.lm_head.
            # If the checkpoint omits a copy of the lm_head weights at the
            # MTP layer path, shared_head.head stays uninitialised and
            # produces NaN logits. Always share it explicitly.
            # getattr on None yields None, so no truthiness test on the
            # (possibly container) inner module is needed.
            layers = getattr(getattr(self.model, "model", None), "layers", None)
            if layers is not None:
                layer_iter = (
                    layers.values() if isinstance(layers, nn.ModuleDict) else layers
                )
                num_rebound = 0
                for layer in layer_iter:
                    shared_head = getattr(layer, "shared_head", None)
                    if shared_head is not None and hasattr(shared_head, "head"):
                        del shared_head.head
                        shared_head.head = target_lm_head
                        num_rebound += 1
                if num_rebound:
                    logger.info(
                        "Shared target model lm_head with %d MTP shared_head.head "
                        "module(s).",
                        num_rebound,
                    )

        if self.use_local_argmax_reduction:
            if not hasattr(self.model, "get_top_tokens"):
                raise ValueError(
                    "use_local_argmax_reduction is enabled but draft model "
                    f"{self.model.__class__.__name__} does not implement "
                    "get_top_tokens()."
                )
            # Warn if draft model has vocab remapping, which forces fallback
            # to the full-logits path (negating the optimization).
            if (
                hasattr(self.model, "draft_id_to_target_id")
                and self.model.draft_id_to_target_id is not None
            ):
                logger.warning(
                    "use_local_argmax_reduction is enabled but draft model "
                    "uses draft_id_to_target_id vocab remapping. The "
                    "optimization will be bypassed (falling back to full "
                    "logits gather + argmax)."
                )
            else:
                logger.info(
                    "Using local argmax reduction for draft token generation "
                    "(communication: O(2*tp_size) vs O(vocab_size))."
                )

1486
1487
1488
1489
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """
        Run the draft model forward on the proposer's own input buffers.

        No draft tokens are produced and the model output is discarded. One
        forward pass is executed per speculative token, or a single pass while
        ``is_graph_capturing`` is True.

        Args:
            num_tokens: Unpadded token count passed to
                ``_determine_batch_execution_and_padding``.
            use_cudagraphs: Forwarded to
                ``_determine_batch_execution_and_padding``.
            is_graph_capturing: Restrict the run to a single forward pass.
            slot_mappings: Slot mappings for the forward context. If it has an
                entry for a draft attention layer, the proposer's own mapping
                from ``_get_slot_mapping`` is used instead.
        """
        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        num_passes = 1 if is_graph_capturing else self.num_speculative_tokens
        for pass_idx in range(num_passes):
            # NOTE(review): execution mode / padding is only (re)computed for
            # the first two passes and reused afterwards. Presumably this
            # mirrors how often the real proposal path plans its batches (and,
            # under data parallelism, joins cross-rank coordination) -- confirm
            # before changing the count.
            if pass_idx < 2:
                (
                    cudagraph_runtime_mode,
                    num_input_tokens,
                    num_tokens_across_dp,
                ) = self._determine_batch_execution_and_padding(
                    num_tokens, use_cudagraphs=use_cudagraphs
                )

            # Make sure to use EAGLE's own buffer during cudagraph capture.
            mapping_covers_draft = (
                bool(self._draft_attn_layer_names)
                and slot_mappings is not None
                and next(iter(self._draft_attn_layer_names)) in slot_mappings
            )
            slot_mapping_dict = (
                self._get_slot_mapping(num_input_tokens)
                if mapping_covers_draft
                else (slot_mappings or {})
            )

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                # Multimodal-capable drafts are fed embeddings; others token ids.
                if self.supports_mm_inputs:
                    token_ids, embeds = None, self.inputs_embeds[:num_input_tokens]
                else:
                    token_ids, embeds = self.input_ids[:num_input_tokens], None

                forward_kwargs = {
                    "input_ids": token_ids,
                    "positions": self._get_positions(num_input_tokens),
                    "inputs_embeds": embeds,
                }
                if self.pass_hidden_states_to_model:
                    forward_kwargs["hidden_states"] = self.hidden_states[
                        :num_input_tokens
                    ]
                self.model(**forward_kwargs)
1539

1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

1556
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all drafting layers belong to the same KVCacheGroup.

        Need this assumption to ensure all drafting layers can use the same
        AttentionMetadata. May extend to multiple AttentionMetadata in the
        future.

        Args:
            kv_cache_config: KV cache layout whose ``kv_cache_groups`` list the
                layer names of each group.

        Raises:
            AssertionError: If some drafting layer is not in any group, or the
                drafting layers do not map to exactly one group (this also
                covers an empty set of drafting layers).
        """
        # If a layer name appears in several groups, the last one wins.
        layer_to_group_id: dict[str, int] = {}
        for group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                layer_to_group_id[layer_name] = group_id

        unassigned = sorted(
            name
            for name in self._draft_attn_layer_names
            if name not in layer_to_group_id
        )
        assert not unassigned, (
            f"Drafting layers {unassigned} do not belong to any kv cache group"
        )

        draft_group_ids = {
            layer_to_group_id[name] for name in self._draft_attn_layer_names
        }
        assert len(draft_group_ids) == 1, (
            "All drafting layers should belong to the same kv cache group"
            f" (found groups: {sorted(draft_group_ids)})"
        )
1578

1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
    def initialize_attn_backend(
        self,
        kv_cache_config: KVCacheConfig,
        kernel_block_sizes: list[int] | None = None,
    ) -> None:
        """
        Initialize AttentionGroups for draft layers using kv_cache_config.
        Called from the model runner's initialize_metadata_builders.

        Draft layers are grouped by attention backend class (one
        AttentionGroup per backend) inside the single KV cache group that
        holds them. Sets ``self.kv_cache_gid``, ``self.draft_attn_groups`` and
        ``self.block_size``.

        Args:
            kv_cache_config: KV cache layout; all draft layers must live in one
                of its groups (checked by ``validate_same_kv_cache_group``).
            kernel_block_sizes: Optional kernel block sizes indexed by KV cache
                group id. ``None``, or a list too short for the draft group,
                yields ``kernel_block_size=None``.
        """
        all_attn_layers = get_layers_from_vllm_config(
            self.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )

        # Find which kv_cache_group the draft layers belong to
        self.validate_same_kv_cache_group(kv_cache_config)
        kv_cache_spec = None
        for gid, group in enumerate(kv_cache_config.kv_cache_groups):
            if self._draft_attn_layer_names & set(group.layer_names):
                self.kv_cache_gid = gid
                kv_cache_spec = group.kv_cache_spec
                break

        attention_groups: dict[tuple[str, str], AttentionGroup] = {}
        if kv_cache_spec is not None:
            # Depends only on the group id, so compute it once.
            kernel_block_size = (
                kernel_block_sizes[self.kv_cache_gid]
                if kernel_block_sizes is not None
                and self.kv_cache_gid < len(kernel_block_sizes)
                else None
            )
            # Sorted: set iteration order follows the per-process string hash
            # seed, yet it decides which layer seeds each group's spec and the
            # group order (and hence the block size picked below).
            for layer_name in sorted(self._draft_attn_layer_names):
                attn_backend = all_attn_layers[layer_name].get_attn_backend()
                backend_key = attn_backend.full_cls_name()
                if backend_key not in attention_groups:
                    layer_kv_cache_spec = kv_cache_spec
                    if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
                        layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
                            layer_name
                        ]

                    attn_group = AttentionGroup(
                        backend=attn_backend,
                        layer_names=[layer_name],
                        kv_cache_spec=layer_kv_cache_spec,
                        kv_cache_group_id=self.kv_cache_gid,
                    )
                    attn_group.create_metadata_builders(
                        self.vllm_config,
                        self.device,
                        kernel_block_size=kernel_block_size,
                    )
                    attention_groups[backend_key] = attn_group
                else:
                    attention_groups[backend_key].layer_names.append(layer_name)

        self.draft_attn_groups = list(attention_groups.values())
        self.block_size = (
            self.draft_attn_groups[0].get_metadata_builder().kv_cache_spec.block_size
        )
        logger.debug("Using block size %d for drafting layers", self.block_size)
1640

1641
    def _determine_batch_execution_and_padding(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        """
        Choose the CUDA graph mode and padded token count for a draft forward.

        The local CUDA graph dispatcher picks a mode and padded size for
        ``num_tokens``. With data parallelism (``data_parallel_size > 1``) that
        choice is coordinated across ranks via ``coordinate_batch_across_dp``,
        then re-dispatched at the agreed size.

        Args:
            num_tokens: Unpadded number of tokens on this rank.
            use_cudagraphs: When False, dispatch is restricted to
                ``CUDAGraphMode.NONE``.

        Returns:
            ``(cudagraph_mode, num_tokens_padded, num_tokens_across_dp)``; the
            last item is None when DP is off or coordination returned None.
        """
        allowed_modes = None if use_cudagraphs else {CUDAGraphMode.NONE}
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens,
            valid_modes=allowed_modes,
        )
        padded_num_tokens = batch_desc.num_tokens

        parallel_config = self.vllm_config.parallel_config
        if parallel_config.data_parallel_size <= 1:
            return cudagraph_mode, padded_num_tokens, None

        # Extra coordination when running data-parallel since we need to
        # coordinate across ranks
        # TODO(Flechman): support DBO ubatching
        should_ubatch, tokens_per_dp_rank, synced_mode = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens,
            parallel_config=parallel_config,
            allow_microbatching=False,
            num_tokens_padded=padded_num_tokens,
            cudagraph_mode=cudagraph_mode.value,
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
        if tokens_per_dp_rank is None:
            return cudagraph_mode, padded_num_tokens, None

        # Adopt the DP-synced size and re-dispatch so batch_desc matches it.
        rank = self.dp_rank
        padded_num_tokens = int(tokens_per_dp_rank[rank].item())
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            padded_num_tokens,
            valid_modes={CUDAGraphMode(synced_mode)},
        )
        # Assert to make sure the agreed upon token count is correct
        # otherwise num_tokens_across_dp will no-longer be valid
        assert batch_desc.num_tokens == padded_num_tokens
        tokens_per_dp_rank[rank] = padded_num_tokens
        return cudagraph_mode, padded_num_tokens, tokens_per_dp_rank
Rémi Delacourt's avatar
Rémi Delacourt committed
1684

1685

1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
class EagleProposer(SpecDecodeBaseProposer):
    """
    EAGLE speculative-decoding proposer.

    A thin specialization of SpecDecodeBaseProposer that always enables
    ``pass_hidden_states_to_model``, so the draft model is called with a
    ``hidden_states`` argument in addition to its token inputs.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        base_options = {"pass_hidden_states_to_model": True, "runner": runner}
        super().__init__(vllm_config, device, **base_options)


1701
1702
1703
1704
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sample one draft token per row and return it with the draft probabilities.

    Only ``temperature`` is honoured. Rows whose temperature is below
    ``_SAMPLING_EPS`` are decoded greedily. The remaining rows are sampled by
    taking the argmax of ``probs / q`` with ``q ~ Exp(1)``, which draws from
    ``probs``.

    Note: unless the batch is all-greedy, ``logits`` is divided by the
    temperature in place.

    Args:
        logits: Draft logits; assumed 2D (rows x vocab), since the temperature
            is broadcast via ``view(-1, 1)`` -- TODO confirm.
        sampling_metadata: Batch sampling settings.

    Returns:
        ``(next_token_ids, probs)``. For an all-greedy batch ``probs`` is the
        unmodified ``logits`` tensor, since rejection sampling never reads it.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling,
        # so the logits are returned in their place.
        return logits.argmax(dim=-1), logits

    assert sampling_metadata.temperature is not None
    temperature = sampling_metadata.temperature

    # Only a batch that is not all-random can contain greedy rows
    # (temperature ~ 0.0, using sampler.py's _SAMPLING_EPS threshold).
    may_have_greedy = not sampling_metadata.all_random
    if may_have_greedy:
        greedy_rows = temperature < _SAMPLING_EPS
        # Avoid division by zero for the greedy rows.
        temperature = torch.where(greedy_rows, 1.0, temperature)

    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    exp_noise = torch.empty_like(probs)
    exp_noise.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    sampled_ids = probs.div(exp_noise).argmax(dim=-1).view(-1)

    if may_have_greedy:
        sampled_ids = torch.where(greedy_rows, probs.argmax(dim=-1), sampled_ids)
    return sampled_ids, probs