eagle.py 76.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6
from typing import cast
7

8
import numpy as np
9
10
11
import torch
import torch.nn as nn

12
13
14
15
16
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
17
from vllm.distributed.parallel_state import get_pp_group
18
from vllm.forward_context import set_forward_context
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
21
from vllm.model_executor.model_loader import get_model
22
from vllm.model_executor.models import supports_multimodal
23
from vllm.model_executor.models.interfaces import SupportsMultiModal
24
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.platforms import current_platform
27
from vllm.triton_utils import triton
28
from vllm.utils.platform_utils import is_pin_memory_available
29
from vllm.v1.attention.backend import CommonAttentionMetadata
30
from vllm.v1.attention.backends.registry import AttentionBackendEnum
31
32
33
34
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
35
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
36
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
37
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
38
from vllm.v1.sample.metadata import SamplingMetadata
39
from vllm.v1.sample.sampler import _SAMPLING_EPS
40
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
41
from vllm.v1.spec_decode.utils import (
42
43
44
    PADDING_SLOT_ID,
    compute_new_slot_mapping,
    copy_and_expand_eagle_inputs_kernel,
45
46
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
47
    extend_all_queries_by_N,
48
)
49
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
50
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
51
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
52
from vllm.v1.worker.utils import AttentionGroup
53

54
55
# Module-level logger, created through vLLM's `init_logger` for this module.
logger = init_logger(__name__)

56

57
class SpecDecodeBaseProposer:
58
59
60
61
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        pass_hidden_states_to_model: bool,
        runner=None,
    ):
        """Read the speculative config and allocate persistent drafting buffers.

        Args:
            vllm_config: Engine configuration. Its ``speculative_config`` must
                not be ``None``.
            device: Device on which all persistent tensors are allocated.
            pass_hidden_states_to_model: Whether ``hidden_states`` is passed as
                a keyword argument to the draft model's forward call.
            runner: Stored on ``self.runner``; ``propose`` asserts that it is
                not ``None``.

        Raises:
            NotImplementedError: If extra input slots are needed and the
                padded drafter batch is disabled, multimodal inputs are
                supported, or the draft model uses M-RoPE.
            ValueError: If parallel drafting is enabled and the draft model's
                HF config defines neither ``pard_token`` nor ``ptd_token_id``.
        """
        self.vllm_config = vllm_config
        assert vllm_config.speculative_config is not None
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method
        self.pass_hidden_states_to_model = pass_hidden_states_to_model

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Unifying eagle, draft model, and parallel drafting support
        self.parallel_drafting: bool = self.speculative_config.parallel_drafting
        self.extra_slots_per_request = (
            1 if not self.parallel_drafting else self.num_speculative_tokens
        )
        self.net_num_new_slots_per_request = self.extra_slots_per_request - (
            1 if self.pass_hidden_states_to_model else 0
        )
        self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0

        self.parallel_drafting_token_id: int = 0
        self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
        if self.parallel_drafting:
            self._init_parallel_drafting_params()
        self.use_local_argmax_reduction: bool = (
            self.speculative_config.use_local_argmax_reduction
        )

        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # We need +1 here because the aranges are used to set query_start_loc,
        # which has one more element than batch_size.
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        # Sized like the device-side `self.arange` below: `propose` slices both
        # with `[: batch_size + 1]`, so the CPU copy must not be shorter.
        self.token_arange_np = np.arange(max_num_slots_for_arange)

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        self.draft_attn_groups: list[AttentionGroup] = []
        self.kv_cache_gid: int = -1
        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        self.compilation_config = self.vllm_config.compilation_config

        # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
        # Keys are initialized later via initialize_cudagraph_keys() called from
        # gpu_model_runner._check_and_update_cudagraph_mode after
        # adjust_cudagraph_sizes_for_spec_decode is called.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        # Use draft model's M-RoPE setting, not target model's
        # Draft models may be text-only even if target is multimodal
        self.uses_mrope = self.draft_model_config.uses_mrope
        self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
        self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions = torch.zeros(
                (self.uses_xdrope_dim, self.max_num_tokens + 1),
                dtype=torch.int64,
                device=device,
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # Will be set when we initialize the attention backend
        self.block_size: int = -1

        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        if self.needs_extra_input_slots:
            self._raise_if_padded_drafter_batch_disabled()
            self._raise_if_multimodal()
            self._raise_if_mrope()

        self.is_rejected_token_mask: torch.Tensor | None = None
        self.is_masked_token_mask: torch.Tensor | None = None
        if self.needs_extra_input_slots:
            # For draft models and parallel drafting, we need to keep track of
            # which tokens are rejected to update the slot mapping with padding slots.
            self.is_rejected_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )
            # For parallel drafting, we also need to keep track of which tokens
            # are parallel-padding tokens used to sample at later positions.
            # We populate this tensor even when using draft models for simplicity.
            self.is_masked_token_mask = torch.zeros(
                (self.max_num_tokens,), dtype=torch.bool, device=device
            )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        self._slot_mapping_buffer = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

        # Determine allowed attention backends once during initialization.
        # `None` means no restriction is enforced in `propose`.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata

            rocm_types = [
                TritonAttentionMetadata,
                RocmAttentionMetadata,
            ]
            # ROCM_AITER_FA is an optional backend
            # We check is_enabled() here to avoid importing the backend module during
            # auto-discovery when VLLM_ROCM_USE_AITER=0, which would trigger aiter
            # import and JIT compilation warnings. Explicit backend selection via
            # attention_config still works because the backend module is loaded
            # directly when selected, not through this auto-discovery path.
            # Check if backend module exists to allow explicit selection
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)

            # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
            from vllm.model_executor.layers.attention.mla_attention import (
                MLACommonMetadata,
            )

            rocm_types.append(MLACommonMetadata)

            # FlexAttention backend support
            from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata

            rocm_types.append(FlexAttentionMetadata)

            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        assert spec_token_tree is not None
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # NOTE(review): the depth is taken from the last entry, so this assumes
        # the deepest node is listed last -- confirm against config validation.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
        ).repeat(max_batch_size, 1)

275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
    def _raise_if_padded_drafter_batch_disabled(self):
        if self.speculative_config.disable_padded_drafter_batch:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting only "
                "supports padded drafter batch. Please unset "
                "disable_padded_drafter_batch in the speculative_config."
            )

    def _raise_if_multimodal(self):
        if self.supports_mm_inputs:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting "
                "does not support multimodal models yet"
            )

    def _raise_if_mrope(self):
        if self.draft_model_config.uses_mrope:
            raise NotImplementedError(
                "Speculative Decoding with draft models or parallel drafting "
                "does not support M-RoPE yet"
            )

    def _init_parallel_drafting_params(self):
        # For parallel drafting, we need the token ID to use for masked slots
        # And for EAGLE + parallel drafting, we need the hidden state tensor to use
        # for those masked slots.

        model_hf_config = self.draft_model_config.hf_config
        if hasattr(model_hf_config, "pard_token"):
            self.parallel_drafting_token_id = model_hf_config.pard_token
        elif hasattr(model_hf_config, "ptd_token_id"):
            self.parallel_drafting_token_id = model_hf_config.ptd_token_id
        else:
            raise ValueError(
                "For parallel drafting, the draft model config must have "
                "`pard_token` or `ptd_token_id` specified in its config.json."
            )

        if self.pass_hidden_states_to_model:
            self.parallel_drafting_hidden_state_tensor = torch.empty(
                self.hidden_size, dtype=self.dtype, device=self.device
            )

318
319
320
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
321
322
        if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            return self.xdrope_positions[:, :num_tokens]
323
324
325
326
327
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
328
329
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions[:, :num_tokens] = positions
330
        else:
331
332
333
334
335
            # Convert M-RoPE positions if target model uses M-RoPE
            # but draft doesn't, For text inputs, all M-RoPE
            # dimensions are identical
            if self.vllm_config.model_config.uses_mrope:
                positions = positions[0]
336
337
            self.positions[:num_tokens] = positions

338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Return slot_mapping dict for EAGLE layers.

        If slot_mapping is provided, copies it into the buffer first.
        """
        if slot_mapping is not None:
            num_actual = slot_mapping.shape[0]
            self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
            if num_tokens > num_actual:
                self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

        view = self._slot_mapping_buffer[:num_tokens]
354
        return {name: view for name in self._draft_attn_layer_names}
355

356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize cudagraph dispatcher keys for eagle.

        Eagle only supports PIECEWISE cudagraphs (via mixed_mode).
        This should be called after adjust_cudagraph_sizes_for_spec_decode.

        Args:
            cudagraph_mode: The configured mode. Its ``mixed_mode()`` selects
                PIECEWISE when it is PIECEWISE or FULL, unless the speculative
                config enforces eager execution; otherwise NONE is used.
        """
        # Default to no cudagraphs; upgrade only when capture is allowed.
        eagle_cudagraph_mode = CUDAGraphMode.NONE
        if not self.speculative_config.enforce_eager:
            mixed_mode = cudagraph_mode.mixed_mode()
            if mixed_mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL):
                eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE

        self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

373
374
375
376
377
378
    def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Greedy-sample draft tokens from hidden states."""
        if self.use_local_argmax_reduction:
            return self.model.get_top_tokens(hidden_states)
        return self.model.compute_logits(hidden_states).argmax(dim=-1)

379
380
381
382
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> torch.Tensor:
        """Run the draft model and return the proposed draft token ids.

        The first forward pass runs over the inputs prepared by
        ``set_inputs_first_pass``. Depending on configuration the method then
        returns immediately (single draft token or parallel drafting), hands
        over to ``propose_tree`` (tree attention), or runs
        ``num_speculative_tokens - 1`` further one-token-per-request passes.

        Args:
            target_token_ids: Token ids from the target model.
            target_positions: Positions of those tokens.
            target_hidden_states: Target hidden states; for ``eagle3`` they are
                first combined through ``model.combine_hidden_states``.
            next_token_ids: Next token id per request.
            token_indices_to_sample: Indices to sample from, or ``None`` to let
                ``set_inputs_first_pass`` choose them.
            common_attn_metadata: Attention metadata. It is mutated in place
                during the sequential drafting loop.
            sampling_metadata: Not read by this method.
            mm_embed_inputs: Optional ``(mm_embeds, is_mm_embed)`` pair used
                when multimodal inputs are supported.
            num_rejected_tokens_gpu: Optional per-request rejected-token
                counts, subtracted from ``seq_lens`` before the loop.
            slot_mappings: Forwarded to ``propose_tree`` only.

        Returns:
            Draft token ids: ``view(-1, num_speculative_tokens)`` on the early
            exit, ``[batch_size, num_tree_tokens]`` for tree drafting, and
            ``[batch_size, num_speculative_tokens]`` otherwise.

        Raises:
            ValueError: If ``allowed_attn_types`` is set and the built
                attention metadata is not an instance of one of them.
        """
        batch_size = common_attn_metadata.batch_size()

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size

        num_tokens, token_indices_to_sample, common_attn_metadata = (
            self.set_inputs_first_pass(
                target_token_ids=target_token_ids,
                next_token_ids=next_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                token_indices_to_sample=token_indices_to_sample,
                cad=common_attn_metadata,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
            )
        )

        assert self.runner is not None

        per_layer_attn_metadata: dict[str, object] = {}
        for attn_group in self.draft_attn_groups:
            attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=0
            )
            for layer_name in attn_group.layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(num_tokens)
        )

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        model_kwargs = {
            "input_ids": input_ids,
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": inputs_embeds,
        }
        if self.pass_hidden_states_to_model:
            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=self._get_slot_mapping(
                num_input_tokens, common_attn_metadata.slot_mapping
            ),
        ):
            ret_hidden_states = self.model(**model_kwargs)
            if not self.model_returns_tuple():
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states

        sample_hidden_states = last_hidden_states[token_indices_to_sample]

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1 or self.parallel_drafting:
            draft_token_ids = self._greedy_sample(sample_hidden_states)
            return draft_token_ids.view(-1, self.num_speculative_tokens)

        if self.uses_mrope:
            positions = self.mrope_positions[:, token_indices_to_sample]
        else:
            # NOTE(review): `self.positions` is not allocated by `__init__`
            # when the XD-RoPE buffer is used -- confirm this path is never
            # reached in that configuration.
            positions = self.positions[token_indices_to_sample]
        hidden_states = hidden_states[token_indices_to_sample]

        # `attn_metadata` is the object built for the last attention group in
        # the loop above.
        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention - requires full logits for top-k
            logits = self.model.compute_logits(sample_hidden_states)
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
                slot_mappings=slot_mappings,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = self._greedy_sample(sample_hidden_states)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                "Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
            self._determine_batch_execution_and_padding(batch_size)
        )

        # From here on every request contributes exactly one query token.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # NOTE(review): the updates below are nevertheless in place --
            # confirm that this is intended.
            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
            # Increment the maximum sequence length. We increment max_seq_len
            # unconditionally even though some seq_lens may have been capped above,
            # as max_seq_len serves as an upper bound for sequence lengths.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )

            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed when common_attn_metadata.seq_lens_cpu is deprecated.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu += 1
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu += 1

            # Compute the slot mapping.
            block_size = self.block_size
            assert block_size > 0, "block_size has not been initialized."
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // block_size
            else:
                block_numbers = clamped_positions // block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions[0] % block_size
                )
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions % block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            for attn_group in self.draft_attn_groups:
                attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=token_index + 1,
                )
                for layer_name in attn_group.layer_names:
                    per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            model_kwargs = {
                "input_ids": input_ids,
                "positions": self._get_positions(input_batch_size),
                "inputs_embeds": inputs_embeds,
            }
            if self.pass_hidden_states_to_model:
                model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]

            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    input_batch_size, common_attn_metadata.slot_mapping
                ),
            ):
                ret_hidden_states = self.model(**model_kwargs)
                if not self.model_returns_tuple():
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states

            # Drop the padded tail before sampling and feeding the next step.
            hidden_states = hidden_states[:batch_size]
            draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
667

668
669
670
671
672
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """
        Stage the inputs of the first drafting pass in the persistent buffers
        (`self.input_ids`, positions and `self.hidden_states`).

        When `self.needs_extra_input_slots` is false, the target token ids are
        shifted left by one, `next_token_ids` are written at
        `token_indices_to_sample` (defaulting to the last token of every
        request), and `cad` is returned untouched.

        Otherwise a Triton kernel copies/expands the inputs into the buffers,
        the slot mapping is recomputed and the attention metadata returned by
        `extend_all_queries_by_N` is used instead of `cad`.

        Returns:
            A tuple of (number of tokens staged in the buffers, indices of the
            tokens to sample from, attention metadata for the pass).
        """
        if not self.needs_extra_input_slots:
            # Default EAGLE pathway: the input tensors keep their shape. The
            # input ids are rotated by one and the positions are left as is;
            # the next token id goes into the last slot of each request.
            staged_tokens = target_token_ids.shape[0]
            if token_indices_to_sample is None:
                token_indices_to_sample = cad.query_start_loc[1:] - 1

            # Shift the input ids by one token.
            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
            self.input_ids[: staged_tokens - 1] = target_token_ids[1:]
            # Replace the last token with the next token.
            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
            self.input_ids[token_indices_to_sample] = next_token_ids

            # Copy inputs to the buffers used for cudagraph.
            target_uses_xdrope_only = (
                self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0
            )
            if target_uses_xdrope_only:
                target_positions = target_positions[0]
            self._set_positions(staged_tokens, target_positions)
            self.hidden_states[:staged_tokens] = target_hidden_states

            return staged_tokens, token_indices_to_sample, cad

        # Pathway with extra input slots per request.
        assert self.is_rejected_token_mask is not None
        assert self.is_masked_token_mask is not None

        # 1.
        # A custom triton kernel copies input_ids and positions into the
        # correct slots of the preallocated buffers self.input_ids and
        # self.positions.
        batch_size = cad.batch_size()
        slots_per_request = self.net_num_new_slots_per_request
        # Prefills may require copying a lot of data, so the block size follows
        # the max query length, capped at 256 slots per block.
        max_tokens_per_request = cad.max_query_len + slots_per_request
        tokens_per_block = min(256, triton.next_power_of_2(max_tokens_per_request))
        blocks_per_request = (
            max_tokens_per_request + tokens_per_block - 1
        ) // tokens_per_block

        num_in_tokens = target_token_ids.shape[0]
        num_out_tokens = num_in_tokens + (slots_per_request * batch_size)

        token_indices_to_sample = torch.empty(
            batch_size * self.extra_slots_per_request,
            dtype=torch.int32,
            device=self.device,
        )
        # Destination indices to write target_hidden_states into drafting buffer.
        hidden_state_dst = torch.empty(
            num_in_tokens, dtype=torch.int32, device=self.device
        )

        # Kernel grid: (request, token block within the request).
        launch_grid = (batch_size, blocks_per_request)
        first_token_loc = cad.query_start_loc
        last_token_loc = cad.query_start_loc[1:] - 1
        if num_rejected_tokens_gpu is not None:
            last_token_loc = last_token_loc - num_rejected_tokens_gpu

        copy_and_expand_eagle_inputs_kernel[launch_grid](
            # (Padded) Inputs from the target model
            target_token_ids_ptr=target_token_ids,
            target_positions_ptr=target_positions,
            next_token_ids_ptr=next_token_ids,  # sampled tokens, one per request
            # Outputs to the drafting buffers
            out_input_ids_ptr=self.input_ids,
            out_positions_ptr=self.positions,  # Doesn't support mrope for now
            out_is_rejected_token_mask_ptr=self.is_rejected_token_mask,
            out_is_masked_token_mask_ptr=self.is_masked_token_mask,
            out_new_token_indices_ptr=token_indices_to_sample,
            out_hidden_state_mapping_ptr=hidden_state_dst,
            # Input metadata
            query_start_loc_ptr=first_token_loc,
            query_end_loc_ptr=last_token_loc,
            padding_token_id=0,
            parallel_drafting_token_id=self.parallel_drafting_token_id,
            # Sizing info
            # Note that we can deduce batch_size for free from the grid size
            total_input_tokens=num_in_tokens,
            num_padding_slots_per_request=self.extra_slots_per_request,
            shift_input_ids=self.pass_hidden_states_to_model,
            BLOCK_SIZE_TOKENS=tokens_per_block,
        )

        if self.pass_hidden_states_to_model:
            assert self.parallel_drafting_hidden_state_tensor is not None
            self.hidden_states[hidden_state_dst] = target_hidden_states
            # torch.where is used instead of boolean indexing to avoid a
            # DtoH sync.
            masked_rows = self.is_masked_token_mask[:num_out_tokens].unsqueeze(1)
            staged_hidden = self.hidden_states[:num_out_tokens]
            torch.where(
                masked_rows,
                self.parallel_drafting_hidden_state_tensor,
                staged_hidden,
                out=staged_hidden,
            )

        # 2.
        # Recompute the slot mapping from the new positions and the
        # rejection mask.
        assert self.block_size > 0, "block_size has not been initialized."
        rejected_mask = self.is_rejected_token_mask[:num_out_tokens]
        recomputed_slot_mapping = compute_new_slot_mapping(
            cad=cad,
            new_positions=self.positions[:num_out_tokens],
            is_rejected_token_mask=rejected_mask,
            block_size=self.block_size,
            num_new_tokens=slots_per_request,
            max_model_len=self.max_model_len,
        )

        # 3. Update the common attention metadata with the new (meta)data
        extended_cad = extend_all_queries_by_N(
            cad,
            N=slots_per_request,
            arange=self.arange,
            new_slot_mapping=recomputed_slot_mapping,
        )

        return num_out_tokens, token_indices_to_sample, extended_cad
802
803
804
805

    def model_returns_tuple(self) -> bool:
        """
        Tell whether the draft model's forward call returns a
        `(last_hidden_states, hidden_states)` pair rather than a single tensor.

        The "mtp" and "draft_model" methods return a single value; every other
        method is treated as returning a tuple.
        """
        returns_single_value = self.method in ("mtp", "draft_model")
        return not returns_single_value

806
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        Prepare the next token id of every request for speculative decoding,
        starting from token ids that were sampled and brought to the CPU.

        For a request with sampled tokens, the last sampled token is used.
        A request without sampled tokens (e.g. during the initial decoding
        steps) falls back to its request state to look up the next token id.

        Returns:
            An int32 tensor on the device of `self.input_ids` holding one
            token id per entry of `sampled_token_ids`.
        """
        batch_req_ids = gpu_input_batch.req_ids

        def _token_from_request_state(row: int) -> int:
            # Partial prefill (rare case): nothing was sampled for this row,
            # so the token is read from the request state at the position
            # reached after the tokens scheduled in this step.
            rid = batch_req_ids[row]
            state = requests[rid]
            return state.get_token_id(
                state.num_computed_tokens + num_scheduled_tokens[rid]
            )

        # Common case: the last sampled token is the next token.
        chosen = [
            sampled[-1] if sampled else _token_from_request_state(row)
            for row, sampled in enumerate(sampled_token_ids)
        ]
        return torch.tensor(chosen, dtype=torch.int32, device=self.input_ids.device)
838

839
840
841
842
843
844
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead. This is
        denoted the "backup" token id. It also counts rejected tokens via
        `sampled_token_ids`.

        Args:
            common_attn_metadata: Supplies `seq_lens_cpu`, used as the lookup
                position of each request's backup token.
            sampled_token_ids: 2-D tensor; its shape is unpacked as
                (batch_size, num_tokens).
            requests: Request states keyed by request id.
            gpu_input_batch: Supplies `num_reqs`, `req_ids` and `vocab_size`.
            discard_request_mask: Boolean tensor (asserted) passed to the
                kernel.

        Returns:
            `(next_token_ids, valid_sampled_tokens_count)`, two int32 tensors
            of length batch_size filled by the Triton kernel.
        """
        # Precompute get_token_id for when there is no valid next token.
        # The request ids and sequence lengths are fetched once (a single
        # tolist() instead of one tensor index + .item() per request).
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [requests[req_ids[i]].get_token_id(seq_lens[i]) for i in range(num_reqs)],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        # Outputs are allocated uninitialized; the kernel writes them.
        next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
        valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

        # Kernel grid: one program per request (row)
        grid = (batch_size,)

        # Find the next power of 2 for block sizes
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count

897
898
899
900
901
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        Prepare the speculator inputs without dropping rejected tokens.

        The common_attn_metadata is rebuilt for speculative decoding with all
        tokens kept as inputs: rejected tokens act as padding and are filtered
        out later through `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.

        Returns:
            `(spec_common_attn_metadata, token_indices_to_sample,
            num_rejected_tokens_gpu)`; the two tensors are int32, one entry
            per request, and are filled by the Triton kernel.
        """
        cad = common_attn_metadata
        n_reqs = cad.num_reqs
        out_device = valid_sampled_tokens_count.device

        # Kernel outputs, one slot per request.
        token_indices_to_sample = torch.empty(
            (n_reqs,), dtype=torch.int32, device=out_device
        )
        num_rejected_tokens_gpu = torch.empty_like(token_indices_to_sample)

        launch_grid = (n_reqs,)
        eagle_prepare_inputs_padded_kernel[launch_grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            cad.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            n_reqs,
        )

        # Sizes are derived from the CPU copy of query_start_loc.
        qsl_cpu = cad.query_start_loc_cpu
        query_len_per_req = qsl_cpu[1:] - qsl_cpu[:-1]
        n_tokens = qsl_cpu[-1].item()

        padded_cad = CommonAttentionMetadata(
            query_start_loc=cad.query_start_loc,
            seq_lens=cad.seq_lens,
            query_start_loc_cpu=qsl_cpu,
            _seq_lens_cpu=cad._seq_lens_cpu,
            _num_computed_tokens_cpu=cad._num_computed_tokens_cpu,
            num_reqs=n_reqs,
            num_actual_tokens=n_tokens,
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=cad.seq_lens_cpu.max().item(),
            block_table_tensor=cad.block_table_tensor,
            slot_mapping=cad.slot_mapping[:n_tokens],
            causal=True,
            dcp_local_seq_lens=cad.dcp_local_seq_lens,
        )

        return padded_cad, token_indices_to_sample, num_rejected_tokens_gpu
957

958
959
960
961
962
963
964
965
966
967
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> list[torch.Tensor]:
        """
        Draft a tree of tokens level by level.

        The root level is sampled from `logits`; each following level re-runs
        the draft model on all tokens drafted so far with freshly built tree
        attention metadata and samples the children of the newest level.

        `slot_mappings` is accepted but not read by this method.

        Returns:
            One tensor of draft token ids per tree level, each viewed as
            `(batch_size, -1)`.
        """

        def _pick_children(level_logits: torch.Tensor, fanout: int) -> torch.Tensor:
            # Greedy pick for a single child, top-k when there are several.
            if fanout == 1:
                picked = level_logits.argmax(dim=-1)
            else:
                picked = torch.topk(level_logits, fanout, dim=-1).indices
            return picked.view(batch_size, -1)

        builder = self.draft_attn_groups[0].get_metadata_builder()
        assert isinstance(builder, TreeAttentionMetadataBuilder)

        cu_drafts = self.cu_drafts_per_level[0]
        drafts_at_level = cu_drafts
        # Sample a draft token for each child at the tree root level.
        fanout = self.child_drafts_per_level[0]
        draft_token_ids = _pick_children(logits, fanout)
        per_level_token_ids = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Empty tensors that the per-level outputs get concatenated onto.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        all_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )

        for level in range(len(self.cu_drafts_per_level) - 1):
            next_level = level + 1

            # Draft positions for RoPE. Position ids that exceed the max model
            # length are masked to 0; otherwise RoPE may hit an out-of-range
            # error.
            over_limit = (positions + cu_drafts) >= self.max_model_len
            draft_positions = torch.where(
                over_limit,
                0,
                positions + next_level,
            ).view(batch_size, -1)

            if drafts_at_level > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    drafts_at_level, dim=1
                )

            if fanout > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    fanout, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = cu_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + drafts_at_level,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=next_level
            )

            # Every draft layer shares the same attention metadata object.
            per_layer_attn_metadata = {
                layer_name: attn_metadata
                for attn_group in self.draft_attn_groups
                for layer_name in attn_group.layer_names
            }

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(over_limit, 1)

            # Compute the slot mapping.
            block_size = builder.kv_cache_spec.block_size
            query_positions = all_draft_positions[:, level : level + query_len]
            block_ids = attn_metadata.block_table.gather(
                dim=1, index=query_positions // block_size
            )
            slot_mapping = block_ids * block_size + query_positions % block_size
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[over_limit] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            self.input_ids[:num_tokens] = tree_input_ids.view(-1)
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens
            )
            num_input_tokens = batch_desc.num_tokens
            padded_slot_mapping = self._get_slot_mapping(
                num_input_tokens, attn_metadata.slot_mapping
            )
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=padded_slot_mapping,
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Keep only the outputs of the drafts added at this level (the
            # trailing `drafts_at_level` entries of every request).
            per_request_hidden = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )
            draft_hidden_states = per_request_hidden[:, -drafts_at_level:]
            per_request_last_hidden = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )
            draft_last_hidden_states = per_request_last_hidden[:, -drafts_at_level:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * drafts_at_level, -1)
            )

            # Sample a draft token for each child at the next tree level.
            fanout = self.child_drafts_per_level[next_level]
            draft_token_ids = _pick_children(logits, fanout)
            per_level_token_ids.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            drafts_at_level = self.cu_drafts_per_level[next_level] - cu_drafts
            cu_drafts = self.cu_drafts_per_level[next_level]
        return per_level_token_ids

1130
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.

        Args:
            common_attn_metadata: Metadata of the target model's forward pass.
            sampled_token_ids: Sampled token ids per request; only the length
                of each inner list is used here.
            num_draft_tokens: Number of draft tokens per request.

        Returns:
            `(spec_common_attn_metadata, token_indices)` where `token_indices`
            lives on the device of `common_attn_metadata.query_start_loc`.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # A request with n > 0 drafts has n + 1 candidate tokens; whatever was
        # not sampled counts as rejected. Requests without drafts reject none.
        num_rejected_per_req = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(num_rejected_per_req, dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        # The numpy array shares memory with the tensor, so the cumsum below
        # fills new_query_start_loc_cpu in place.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Convert to a plain Python int so that no NumPy scalar ends up in the
        # attention metadata (`num_actual_tokens`).
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            _seq_lens_cpu=new_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE(review): this is the max of the query lengths *before*
            # rejected tokens are removed, i.e. an upper bound on the new
            # query lengths - confirm this is intended.
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
1231

1232
    def get_model_name(self, model: nn.Module) -> str:
        """
        Return the class name of `model`, looking through a wrapper that
        exposes the real model as `.module` (multi-GPU case).
        """
        unwrapped = getattr(model, "module", model)
        return unwrapped.__class__.__name__

1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
    def _get_model(self) -> nn.Module:
        """
        Load the draft model through get_model() under the "eagle_head" model
        tag. Subclasses that need custom model loading can override this.
        """
        from vllm.compilation.backends import set_model_tag

        spec_config = self.speculative_config
        with set_model_tag("eagle_head"):
            return get_model(
                vllm_config=self.vllm_config,
                model_config=spec_config.draft_model_config,
                load_config=spec_config.draft_load_config,
            )

1252
    def load_model(self, target_model: nn.Module) -> None:
        """
        Load the draft model and wire it up to `target_model`.

        Steps, in order: record which attention layers the draft model adds,
        probe whether the draft model accepts multimodal embedding inputs,
        copy the image token index from a multimodal target model, share
        embeddings / lm_head with the target language model where applicable,
        and initialise the parallel-drafting hidden state tensor.
        """
        # Attention layers registered before the draft model is loaded.
        layers_before_draft = set(
            get_layers_from_vllm_config(
                self.vllm_config,
                AttentionLayerBase,  # type: ignore[type-abstract]
            ).keys()
        )

        self.model = self._get_model()

        # Find draft layers (attention layers added by draft model)
        layers_after_draft = get_layers_from_vllm_config(
            self.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        self._draft_attn_layer_names = (
            set(layers_after_draft.keys()) - layers_before_draft
        )

        if self.supports_mm_inputs:
            # A multimodal target model can still be paired with a text-only
            # draft model, so probe the draft model's embedding entry point.
            try:
                probe_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(probe_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        target_language_model = target_model
        if supports_multimodal(target_model):
            # handle multimodality
            assert hasattr(target_model, "config")
            target_name = self.get_model_name(target_model)
            # Architectures whose config exposes the token as `image_token_id`.
            uses_image_token_id = target_name in [
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
                "Qwen3VLMoeForConditionalGeneration",
                "HunYuanVLForConditionalGeneration",
                "GlmOcrForConditionalGeneration",
                "Qwen3_5ForConditionalGeneration",
                "Qwen3_5MoeForConditionalGeneration",
            ]
            if uses_image_token_id:
                image_token = target_model.config.image_token_id
            elif target_name == "PixtralForConditionalGeneration":
                image_token = target_model.config.vision_config.image_token_id
            else:
                image_token = target_model.config.image_token_index
            self.model.config.image_token_index = image_token

            target_language_model = cast(
                SupportsMultiModal, target_model
            ).get_language_model()

        self._maybe_share_embeddings(target_language_model)
        self._maybe_share_lm_head(target_language_model)

        if self.parallel_drafting and self.pass_hidden_states_to_model:
            assert self.parallel_drafting_hidden_state_tensor is not None
            if self.eagle3_use_aux_hidden_state:
                # With auxiliary hidden states the mask is viewed as
                # 3 * hidden_size and passed through combine_hidden_states.
                mask_hidden = self.model.combine_hidden_states(
                    self.model.mask_hidden.view(3 * self.hidden_size)
                )
            else:
                mask_hidden = self.model.mask_hidden.view(self.hidden_size)
            self.parallel_drafting_hidden_state_tensor.copy_(mask_hidden)

    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        """
        Some draft models may not have their own embedding layers, and some may
        have a duplicate copy of the target model's embedding layers. In these cases,
        we share the target model's embedding layers with the draft model to save
        memory.

        Sharing is only attempted when the pipeline-parallel world size is 1;
        otherwise the draft model keeps its own embedding and this only logs.

        Args:
            target_language_model: Target (language) model; must expose
                ``model.embed_tokens`` or ``model.embedding``.

        Raises:
            AttributeError: If the target has no ``model`` attribute, or that
                inner model has neither ``embed_tokens`` nor ``embedding``.
        """
        if get_pp_group().world_size != 1:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )
            return

        inner_model = getattr(target_language_model, "model", None)
        if inner_model is None:
            raise AttributeError("Target model does not have 'model' attribute")
        if hasattr(inner_model, "embed_tokens"):
            target_embed_tokens = inner_model.embed_tokens
        elif hasattr(inner_model, "embedding"):
            target_embed_tokens = inner_model.embedding
        else:
            raise AttributeError(
                "Target model does not have 'embed_tokens' or 'embedding' attribute"
            )

        share_embeddings = False
        if hasattr(self.model, "has_own_embed_tokens"):
            # EAGLE model
            if not self.model.has_own_embed_tokens:
                share_embeddings = True
                logger.info(
                    "Detected EAGLE model without its own embed_tokens in the"
                    " checkpoint. Sharing target model embedding weights with the"
                    " draft model."
                )
            elif (
                isinstance(target_embed_tokens.weight, torch.Tensor)
                and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                # torch.equal is False for mismatched sizes anyway, so check
                # the shapes first and skip the host copies when they differ.
                and target_embed_tokens.weight.shape
                == self.model.model.embed_tokens.weight.shape
                # Compare on CPU to avoid extra GPU memory usage in CI
                # testing environments with limited GPU memory.
                and torch.equal(
                    target_embed_tokens.weight.cpu(),
                    self.model.model.embed_tokens.weight.cpu(),
                )
            ):
                share_embeddings = True
                logger.info(
                    "Detected EAGLE model with embed_tokens identical to the target"
                    " model. Sharing target model embedding weights with the draft"
                    " model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct embed_tokens weights. "
                    "Keeping separate embedding weights from the target model."
                )
        else:
            # MTP model
            share_embeddings = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model embedding weights with the draft model."
            )

        if share_embeddings:
            # Drop the draft model's own module (if any) before aliasing the
            # target's, so the duplicate can be released.
            if hasattr(self.model.model, "embed_tokens"):
                del self.model.model.embed_tokens
            self.model.model.embed_tokens = target_embed_tokens
    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        """
        Some draft models may not have their own LM head, and some may have a
        duplicate copy of the target model's LM head. In these cases, we share
        the target model's LM head with the draft model to save memory.

        After the sharing decision this also validates and logs the
        ``use_local_argmax_reduction`` setting.

        Args:
            target_language_model: Target (language) model whose ``lm_head``
                may be shared with the draft model.

        Raises:
            ValueError: If ``use_local_argmax_reduction`` is enabled but the
                draft model does not implement ``get_top_tokens``.
        """
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                # torch.equal is False for mismatched sizes anyway, so check
                # the shapes first and skip the host copies when they differ.
                and target_language_model.lm_head.weight.shape
                == self.model.lm_head.weight.shape
                # Compare on CPU to avoid extra GPU memory usage in CI
                # testing environments with limited GPU memory.
                and torch.equal(
                    target_language_model.lm_head.weight.cpu(),
                    self.model.lm_head.weight.cpu(),
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head

            # MTP models call compute_logits via shared_head.head (a
            # ParallelLMHead inside each MTP layer), not self.model.lm_head.
            # If the checkpoint omits a copy of the lm_head weights at the
            # MTP layer path, shared_head.head stays uninitialised and
            # produces NaN logits. Always share it explicitly.
            inner = getattr(self.model, "model", None)
            # getattr() with a default is safe when ``inner`` is None, and
            # avoids relying on the truthiness of an nn.Module.
            layers = getattr(inner, "layers", None)
            if layers is not None:
                items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
                for layer in items:
                    sh = getattr(layer, "shared_head", None)
                    if sh is not None and hasattr(sh, "head"):
                        del sh.head
                        sh.head = target_language_model.lm_head
                        logger.info(
                            "Shared target model lm_head with MTP shared_head.head."
                        )

        if self.use_local_argmax_reduction:
            if not hasattr(self.model, "get_top_tokens"):
                raise ValueError(
                    "use_local_argmax_reduction is enabled but draft model "
                    f"{self.model.__class__.__name__} does not implement "
                    "get_top_tokens()."
                )
            # Warn if draft model has vocab remapping, which forces fallback
            # to the full-logits path (negating the optimization).
            if getattr(self.model, "draft_id_to_target_id", None) is not None:
                logger.warning(
                    "use_local_argmax_reduction is enabled but draft model "
                    "uses draft_id_to_target_id vocab remapping. The "
                    "optimization will be bypassed (falling back to full "
                    "logits gather + argmax)."
                )
            else:
                logger.info(
                    "Using local argmax reduction for draft token generation "
                    "(communication: O(2*tp_size) vs O(vocab_size))."
                )

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run the draft model on placeholder inputs from the proposer's buffers.

        Args:
            num_tokens: Unpadded token count; padding is decided by
                ``_determine_batch_execution_and_padding``.
            use_cudagraphs: Forwarded to
                ``_determine_batch_execution_and_padding``.
            is_graph_capturing: When True only one forward pass is run;
                otherwise ``self.num_speculative_tokens`` passes are run.
            slot_mappings: Optional slot mappings keyed by layer name. If it
                has an entry for a draft attention layer, the proposer's own
                mapping from ``_get_slot_mapping`` is used instead; otherwise
                it is passed through (``{}`` when None or empty).
        """
        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        for fwd_idx in range(
            self.num_speculative_tokens if not is_graph_capturing else 1
        ):
            # Padding and cudagraph mode are only (re)computed for the first
            # two passes; later passes reuse the values from the second pass.
            if fwd_idx <= 1:
                cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
                    self._determine_batch_execution_and_padding(
                        num_tokens, use_cudagraphs=use_cudagraphs
                    )
                )

            # Make sure to use EAGLE's own buffer during cudagraph capture.
            if (
                self._draft_attn_layer_names
                and slot_mappings is not None
                and next(iter(self._draft_attn_layer_names)) in slot_mappings
            ):
                slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
            else:
                slot_mapping_dict = slot_mappings or {}

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                # Multimodal-capable drafts are fed embeddings, text-only
                # drafts are fed token ids; exactly one of the two is set.
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                kwargs = dict(
                    input_ids=input_ids,
                    positions=self._get_positions(num_input_tokens),
                    inputs_embeds=inputs_embeds,
                )
                if self.pass_hidden_states_to_model:
                    kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
                self.model(**kwargs)
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all drafting layers belong to the same KVCacheGroup.
        Need this assumption to ensure all drafting layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Args:
            kv_cache_config: Config whose ``kv_cache_groups`` each list their
                ``layer_names``.

        Raises:
            AssertionError: If the draft layers map to zero or to more than
                one KV cache group.
            KeyError: If a draft layer is not listed in any KV cache group.
        """
        # Map every layer name to the index of the group that lists it.
        layer_to_group: dict[str, int] = {
            layer_name: group_idx
            for group_idx, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        draft_group_ids = {
            layer_to_group[layer_name] for layer_name in self._draft_attn_layer_names
        }
        assert len(draft_group_ids) == 1, (
            "All drafting layers should belong to the same kv cache group"
        )
    def initialize_attn_backend(
        self,
        kv_cache_config: KVCacheConfig,
        kernel_block_sizes: list[int] | None = None,
    ) -> None:
        """
        Initialize AttentionGroups for draft layers using kv_cache_config.
        Called from the model runner's initialize_metadata_builders.

        Draft layers are grouped by attention backend (one AttentionGroup per
        distinct ``full_cls_name()``). Sets ``self.kv_cache_gid``,
        ``self.draft_attn_groups`` and ``self.block_size``.

        Args:
            kv_cache_config: KV cache layout; the group containing the draft
                layers provides the KV cache spec.
            kernel_block_sizes: Optional per-KV-cache-group kernel block
                sizes, indexed by group id. When None or too short, ``None``
                is passed to the metadata builders.
        """
        all_attn_layers = get_layers_from_vllm_config(
            self.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )

        # Find which kv_cache_group the draft layers belong to
        self.validate_same_kv_cache_group(kv_cache_config)
        kv_cache_spec = None
        for gid, group in enumerate(kv_cache_config.kv_cache_groups):
            if self._draft_attn_layer_names & set(group.layer_names):
                self.kv_cache_gid = gid
                kv_cache_spec = group.kv_cache_spec
                break

        attention_groups: dict[tuple[str, str], AttentionGroup] = {}
        if kv_cache_spec is not None:
            # Depends only on the (single) draft KV cache group, so compute
            # it once rather than per layer.
            kernel_block_size = (
                kernel_block_sizes[self.kv_cache_gid]
                if kernel_block_sizes is not None
                and self.kv_cache_gid < len(kernel_block_sizes)
                else None
            )
            # NOTE(review): this iterates a set, so which layer is seen first
            # for each backend (and hence the order of draft_attn_groups) is
            # not fixed - confirm nothing relies on that order.
            for layer_name in self._draft_attn_layer_names:
                attn_backend = all_attn_layers[layer_name].get_attn_backend()
                backend_key = attn_backend.full_cls_name()

                existing_group = attention_groups.get(backend_key)
                if existing_group is not None:
                    existing_group.layer_names.append(layer_name)
                    continue

                # First layer seen for this backend: create its group.
                layer_kv_cache_spec = kv_cache_spec
                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
                        layer_name
                    ]

                attn_group = AttentionGroup(
                    backend=attn_backend,
                    layer_names=[layer_name],
                    kv_cache_spec=layer_kv_cache_spec,
                    kv_cache_group_id=self.kv_cache_gid,
                )
                attn_group.create_metadata_builders(
                    self.vllm_config,
                    self.device,
                    kernel_block_size=kernel_block_size,
                )
                attention_groups[backend_key] = attn_group

        self.draft_attn_groups = list(attention_groups.values())
        self.block_size = (
            self.draft_attn_groups[0].get_metadata_builder().kv_cache_spec.block_size
        )
        logger.debug("Using block size %d for drafting layers", self.block_size)
    def _determine_batch_execution_and_padding(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        """Pick the cudagraph mode and padded token count for a draft pass.

        Args:
            num_tokens: Unpadded number of tokens in the batch.
            use_cudagraphs: When False, dispatch is restricted to
                ``CUDAGraphMode.NONE``.

        Returns:
            ``(cudagraph_mode, num_tokens_padded, num_tokens_across_dp)``.
            ``num_tokens_across_dp`` is None when data parallelism is off, or
            when DP coordination returns None.
        """
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens,
            valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
        )
        num_tokens_padded = batch_desc.num_tokens

        # Extra coordination when running data-parallel since we need to
        # coordinate across ranks
        # TODO(Flechman): support DBO ubatching
        num_tokens_across_dp = None
        if self.vllm_config.parallel_config.data_parallel_size > 1:
            should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
                coordinate_batch_across_dp(
                    num_tokens_unpadded=num_tokens,
                    parallel_config=self.vllm_config.parallel_config,
                    allow_microbatching=False,
                    num_tokens_padded=num_tokens_padded,
                    cudagraph_mode=cudagraph_mode.value,
                )
            )
            assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

            # Extract DP-synced values
            if num_tokens_across_dp is not None:
                dp_rank = self.dp_rank
                num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
                # Re-dispatch with DP padding so we have the correct
                # batch_descriptor
                cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                    num_tokens_padded,
                    valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
                )
                # Assert to make sure the agreed upon token count is correct
                # otherwise num_tokens_across_dp will no-longer be valid
                assert batch_desc.num_tokens == num_tokens_padded
                num_tokens_across_dp[dp_rank] = num_tokens_padded

        return cudagraph_mode, num_tokens_padded, num_tokens_across_dp


class EagleProposer(SpecDecodeBaseProposer):
    """Proposer that runs ``SpecDecodeBaseProposer`` with hidden states passed
    to the draft model (``pass_hidden_states_to_model=True``)."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        # This class adds nothing beyond fixing the base-class flag below.
        super().__init__(
            vllm_config,
            device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of ``logits`` and return the draft probs.

    Args:
        logits: Draft logits; the last dimension is the one reduced by
            argmax/softmax. Unless ``sampling_metadata.all_greedy`` is set,
            this tensor is divided by the temperature IN PLACE.
            (Presumably shaped (num_tokens, vocab_size), since the
            temperature is broadcast as a column - confirm with callers.)
        sampling_metadata: Read for ``all_greedy``, ``all_random`` and
            ``temperature``.

    Returns:
        ``(next_token_ids, probs)``. In the all-greedy case ``probs`` is the
        untouched ``logits`` tensor itself; otherwise it is the float32
        softmax of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        return logits.argmax(dim=-1), logits

    assert sampling_metadata.temperature is not None
    temperature = sampling_metadata.temperature

    # Mask of greedy rows; stays None when every request is random so the
    # name is always bound for the check further down.
    is_greedy: torch.Tensor | None = None
    if not sampling_metadata.all_random:
        # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
        # consistent with sampler.py's _SAMPLING_EPS threshold
        is_greedy = temperature < _SAMPLING_EPS
        # Avoid division by zero if there are greedy requests.
        temperature = torch.where(is_greedy, 1.0, temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if is_greedy is not None:
        # Greedy rows take the plain argmax instead of the sampled token.
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids)
    return next_token_ids, probs