eagle.py 59.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
12
13
14
15
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
16
from vllm.distributed.parallel_state import get_pp_group
17
from vllm.forward_context import set_forward_context
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
20
from vllm.model_executor.model_loader import get_model
21
from vllm.model_executor.models import supports_multimodal
22
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
23
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
24
from vllm.multimodal import MULTIMODAL_REGISTRY
25
from vllm.platforms import current_platform
26
from vllm.triton_utils import triton
27
from vllm.utils.platform_utils import is_pin_memory_available
28
29
30
31
from vllm.v1.attention.backend import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
32
from vllm.v1.attention.backends.registry import AttentionBackendEnum
33
34
35
36
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
37
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
38
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
39
from vllm.v1.kv_cache_interface import KVCacheConfig
40
from vllm.v1.sample.metadata import SamplingMetadata
41
from vllm.v1.sample.sampler import _SAMPLING_EPS
42
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
43
44
45
46
from vllm.v1.spec_decode.utils import (
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
)
47
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
48
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
49
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
50

51
52
logger = init_logger(__name__)

53
54
PADDING_SLOT_ID = -1

55

56
class SpecDecodeBaseProposer:
57
58
59
60
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        pass_hidden_states_to_model: bool,
        runner=None,
    ):
        """Read the speculative config and allocate the drafter's persistent buffers.

        Args:
            vllm_config: Engine config; its `speculative_config` must be set
                (asserted below).
            device: Device on which all persistent tensors are allocated.
            pass_hidden_states_to_model: Stored as-is; when True, `propose`
                copies hidden states into `self.hidden_states` and forwards
                them to the draft model as `hidden_states`.
            runner: Owning model runner, stored as `self.runner`.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method
        self.pass_hidden_states_to_model = pass_hidden_states_to_model

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

        # The drafter can get longer sequences than the target model.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
        )
        self.token_arange_np = np.arange(self.max_num_tokens)

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Populated later; `propose` falls back to
        # `_get_attention_metadata_builder()` while this is still None.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None

        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        self.compilation_config = self.vllm_config.compilation_config

        # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
        # Keys are initialized later via initialize_cudagraph_keys() called from
        # gpu_model_runner._check_and_update_cudagraph_mode after
        # adjust_cudagraph_sizes_for_spec_decode is called.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            # Only `mrope_positions` exists in this branch; `self.positions`
            # is intentionally not allocated.
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Determine allowed attention backends once during initialization.
        # None means "no restriction"; only ROCm narrows the accepted types.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata

            rocm_types = [
                TritonAttentionMetadata,
                RocmAttentionMetadata,
            ]
            # ROCM_AITER_FA is an optional backend
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)

            # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
            from vllm.model_executor.layers.attention.mla_attention import (
                MLACommonMetadata,
            )

            rocm_types.append(MLACommonMetadata)

            # FlexAttention backend support
            from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata

            rocm_types.append(FlexAttentionMetadata)

            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # NOTE(review): assumes tree_choices is ordered so that its last node
        # is the deepest one — confirm against the config parser.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
        ).repeat(max_batch_size, 1)

215
216
217
218
219
220
221
222
223
224
225
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
            self.positions[:num_tokens] = positions

226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Register the cudagraph dispatch keys used by the drafter.

        The drafter only captures PIECEWISE graphs. PIECEWISE keys are
        registered when `speculative_config.enforce_eager` is false and the
        mixed-batch mode of `cudagraph_mode` is PIECEWISE or FULL; in every
        other case NONE is registered. Call this after
        adjust_cudagraph_sizes_for_spec_decode.
        """
        eagle_cudagraph_mode = CUDAGraphMode.NONE
        graphs_allowed = not self.speculative_config.enforce_eager
        # `and` short-circuits, so mixed_mode() is only queried when graphs
        # are allowed at all.
        if graphs_allowed and cudagraph_mode.mixed_mode() in (
            CUDAGraphMode.PIECEWISE,
            CUDAGraphMode.FULL,
        ):
            eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE
        self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

243
244
245
246
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the draft model and return the proposed draft token ids.

        The first forward pass consumes the shifted target tokens. Each later
        pass feeds back the previously drafted token (one query token per
        request) until `num_speculative_tokens` drafts exist. When the
        attention metadata is `TreeAttentionMetadata`, drafting is delegated to
        `propose_tree`. Drafts are chosen greedily with `argmax`.

        Returns:
            Draft token ids of shape [batch_size, num_speculative_tokens], or
            [batch_size, num_tree_tokens] when tree attention is used.

        Raises:
            ValueError: If more than one draft token is requested and the
                attention metadata type is not in `self.allowed_attn_types`.
        """
        batch_size = common_attn_metadata.batch_size()

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size

        num_tokens, last_token_indices, common_attn_metadata = (
            self.set_inputs_first_pass(
                target_token_ids=target_token_ids,
                next_token_ids=next_token_ids,
                target_positions=target_positions,
                last_token_indices=last_token_indices,
                cad=common_attn_metadata,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
            )
        )

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
        )

        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens_dp_padded
        )
        num_input_tokens = batch_desc.num_tokens
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        if self.pass_hidden_states_to_model:
            # target_hidden_states and self.hidden_states can have different
            # hidden dims. E.g. large target model and small draft model.
            self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        model_kwargs = {
            "input_ids": input_ids,
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": inputs_embeds,
        }
        if self.pass_hidden_states_to_model:
            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(**model_kwargs)
            if not self.model_returns_tuple():
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states

        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        # Gather the position of each request's last token. With M-RoPE the
        # buffer is `self.mrope_positions`; `self.positions` is never
        # allocated in that configuration, so reading it would raise
        # AttributeError.
        if self.uses_mrope:
            positions = self.mrope_positions[:, last_token_indices]
        else:
            positions = self.positions[last_token_indices]
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
        )

        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            batch_size_dp_padded
        )
        input_batch_size = batch_desc.num_tokens
        if batch_size_across_dp is not None:
            batch_size_across_dp[self.dp_rank] = input_batch_size

        # From here on every request contributes exactly one query token.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # NOTE(review): the seq_lens updates below are nevertheless in place
            # (`+=`, `masked_fill_`); confirm seq_lens cannot alias a buffer the
            # main model reuses.
            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Increment the maximum sequence length. We increment max_seq_len
            # unconditionally even though some seq_lens may have been capped above,
            # as max_seq_len serves as an upper bound for sequence lengths.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )

            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed in when common_attn_metadata.seq_lens_cpu is deprecated.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu += 1
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu += 1

            # Compute the slot mapping.
            block_size = attn_metadata_builder.kv_cache_spec.block_size
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // block_size
            else:
                block_numbers = clamped_positions // block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions[0] % block_size
                )
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions % block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            model_kwargs = {
                "input_ids": input_ids,
                "positions": self._get_positions(input_batch_size),
                "inputs_embeds": inputs_embeds,
            }
            if self.pass_hidden_states_to_model:
                model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]

            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                ret_hidden_states = self.model(**model_kwargs)
                if not self.model_returns_tuple():
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states

            # Drop the cudagraph padding rows before sampling.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
565

566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Populate the input-id and position buffers for the first draft pass.

        The target tokens are shifted left by one slot, and each request's last
        slot is overwritten with its freshly sampled next token.
        `num_rejected_tokens_gpu` is accepted but not used here.

        Returns:
            `(num_tokens, last_token_indices, cad)`. `last_token_indices` is
            derived from `cad.query_start_loc` when None was passed; `cad` is
            returned unchanged.
        """
        if last_token_indices is None:
            # The last token of request i sits just before request i+1 starts.
            last_token_indices = cad.query_start_loc[1:] - 1

        num_tokens = target_token_ids.shape[0]

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        shifted_ids = target_token_ids[1:]
        self.input_ids[: num_tokens - 1] = shifted_ids

        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)

        return num_tokens, last_token_indices, cad

    def model_returns_tuple(self) -> bool:
        return self.method not in ("mtp", "draft_model")

594
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        Build the next input token id for every request from CPU-side samples.

        For each row of `sampled_token_ids`, the last sampled id is used. A row
        with no samples (a partial prefill) falls back to the request state:
        the token at position `num_computed_tokens + num_scheduled_tokens` is
        looked up via `get_token_id`.

        Returns:
            An int32 tensor with one entry per row of `sampled_token_ids`, on
            the same device as `self.input_ids`.
        """
        req_ids = gpu_input_batch.req_ids
        chosen: list[int] = []
        for row, samples in enumerate(sampled_token_ids):
            if not samples:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                rid = req_ids[row]
                state = requests[rid]
                seq_len = state.num_computed_tokens + num_scheduled_tokens[rid]
                chosen.append(state.get_token_id(seq_len))
                continue
            # Common case.
            chosen.append(samples[-1])
        return torch.tensor(chosen, dtype=torch.int32, device=self.input_ids.device)
626

627
628
629
630
631
632
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead. This is
        denoted the "backup" token id. It also counts rejected tokens via
        `sampled_token_ids`.

        Args:
            common_attn_metadata: Provides `seq_lens_cpu`, used to look up each
                request's backup token.
            sampled_token_ids: 2-D tensor unpacked as (batch_size, num_tokens).
            requests: Request states keyed by request id.
            gpu_input_batch: Supplies `num_reqs`, `req_ids` and `vocab_size`.
            discard_request_mask: Boolean mask (asserted below).

        Returns:
            `(next_token_ids, valid_sampled_tokens_count)`, both int32 tensors
            of length `batch_size` on the device of `sampled_token_ids`, filled
            in by `eagle_prepare_next_token_padded_kernel`.
        """
        # Precompute get_token_id for when there is no valid next token.
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        # Read the CPU sequence lengths once, rather than re-evaluating the
        # `seq_lens_cpu` attribute and calling `.item()` for every request.
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [
                requests[req_ids[i]].get_token_id(seq_lens[i])
                for i in range(num_reqs)
            ],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
        valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

        # Kernel grid: one program per request (row)
        grid = (batch_size,)

        # Find the next power of 2 for block sizes
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count

685
686
687
688
689
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        Build the speculator's attention metadata without trimming rejected
        tokens.

        Every scheduled token, rejected ones included, stays an input to the
        speculator; rejected tokens act as padding and are filtered out later
        through `token_indices_to_sample`. The query layout of
        `common_attn_metadata` is therefore reused unchanged.
        No blocking CPU operations should be introduced in this function.

        Returns:
            `(spec_common_attn_metadata, token_indices_to_sample,
            num_rejected_tokens_gpu)`; both tensors are int32 with one entry per
            request and are filled by `eagle_prepare_inputs_padded_kernel`.
        """
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device

        def _alloc_per_request() -> torch.Tensor:
            return torch.empty((num_reqs,), dtype=torch.int32, device=device)

        token_indices_to_sample = _alloc_per_request()
        num_rejected_tokens_gpu = _alloc_per_request()

        eagle_prepare_inputs_padded_kernel[(num_reqs,)](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )

        # The query layout is unchanged; derive its statistics from the CPU copy.
        qsl_cpu = common_attn_metadata.query_start_loc_cpu
        query_lens = qsl_cpu[1:] - qsl_cpu[:-1]
        total_num_tokens = qsl_cpu[-1].item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=qsl_cpu,
            _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=query_lens.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )
        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )

    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
        """
        Expand a draft-token tree level by level using tree attention.

        The root level is sampled from `logits`. Each following level runs the
        draft model over every draft token produced so far (a query of
        `cu_drafts_per_level[level]` tokens per request) and samples children
        for the newest drafts: argmax for a single child, top-k otherwise.

        Returns:
            One tensor of draft token ids per tree level, each viewed as
            `[batch_size, -1]`.
        """
        builder = self.runner.attn_groups[0][0].get_metadata_builder()
        assert isinstance(builder, TreeAttentionMetadataBuilder)

        def sample_children(level_logits: torch.Tensor, k: int) -> torch.Tensor:
            # Greedy pick for a single child, otherwise the k best candidates.
            if k == 1:
                return level_logits.argmax(dim=-1).view(batch_size, -1)
            return torch.topk(level_logits, k, dim=-1).indices.view(batch_size, -1)

        def empty_like_buffer(buf: torch.Tensor) -> torch.Tensor:
            return torch.empty(0, device=buf.device, dtype=buf.dtype)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        draft_token_ids = sample_children(logits, num_children)
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Empty 1-D accumulators; torch.cat(dim=1) accepts them on the first level.
        tree_input_ids = empty_like_buffer(self.input_ids)
        tree_positions = empty_like_buffer(self.positions)
        tree_hidden_states = empty_like_buffer(self.hidden_states)

        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )

        for level in range(len(self.cu_drafts_per_level) - 1):
            # Zero out positions past max_model_len so RoPE stays in range.
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            draft_positions = torch.where(
                exceeds_max_model_len, 0, positions + (level + 1)
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # One position copy per draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )
            if num_children > 1:
                # One hidden-state copy per child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Rebuild the attention metadata so this level attends over the tree.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=level + 1
            )
            per_layer_attn_metadata = {
                layer_name: attn_metadata for layer_name in self.attn_layer_names
            }

            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # Requests past max_model_len get seq_len 1 to keep attention cheap.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            block_size = builder.kv_cache_spec.block_size
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_ids = attn_metadata.block_table.gather(
                dim=1, index=query_positions // block_size
            )
            slot_mapping = block_ids * block_size + query_positions % block_size
            # Padding slots keep out-of-range drafts from writing into the KV cache.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Stage inputs in the persistent buffers used for CUDA graphs.
            num_tokens = attn_metadata.num_actual_tokens
            self.input_ids[:num_tokens] = tree_input_ids.view(-1)
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens
            )
            num_input_tokens = batch_desc.num_tokens
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                model_last_hidden, model_hidden = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Keep only the outputs belonging to the drafts added at this level.
            draft_hidden_states = model_hidden[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = model_last_hidden[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            level_logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            draft_token_ids = sample_children(level_logits, num_children)
            draft_token_ids_list.append(draft_token_ids)

            next_total = self.cu_drafts_per_level[level + 1]
            level_num_drafts = next_total - total_num_drafts
            total_num_drafts = next_total
        return draft_token_ids_list

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # For a request with n > 0 draft tokens the rejected count is
        # (n + 1) - len(sampled_token_ids[i]); requests without drafts reject
        # nothing.
        num_rejected_tokens = torch.tensor(
            [
                n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ],
            dtype=torch.int32,
        )

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        # (pre-rejection query lengths)
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Convert the numpy scalar to a plain Python int; it is used as a slice
        # bound below and as `num_actual_tokens`.
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming new_num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_loc_np` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            _seq_lens_cpu=new_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE: pre-rejection maximum; an upper bound on the new per-request
            # token counts as long as the rejected counts are non-negative.
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices

    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of `model`, looking through one `.module` wrapper.

        Multi-GPU wrappers that keep the real model under `.module` are
        unwrapped once, so the inner model's class name is reported.
        """
        inner = getattr(model, "module", model)
        return inner.__class__.__name__

    def load_model(self, target_model: nn.Module) -> None:
        """
        Load the draft model and connect it to the target model.

        Steps:
          1. Record the target's attention / DeepseekV32 indexer layer names,
             load the draft model under the "eagle_head" model tag, and keep
             the layer names that appear only after loading.
          2. Build an indexer metadata builder when such new indexer layers
             exist (otherwise store None).
          3. Probe multimodal support of the draft model and copy the target's
             image token id into the draft config.
          4. Share the target's embedding (single PP rank only) and lm_head with
             the draft model when the draft checkpoint lacks them or holds
             identical weights; MTP models (no `has_own_*` attribute) share.
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config

        def registered_layers(layer_type):
            return get_layers_from_vllm_config(self.vllm_config, layer_type)

        target_attn_layer_names = set(registered_layers(AttentionLayerBase).keys())
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            registered_layers(DeepseekV32IndexerCache).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        # Layers present now but not before loading belong to the draft model.
        draft_attn_layer_names = (
            registered_layers(AttentionLayerBase).keys() - target_attn_layer_names
        )
        indexer_layers = registered_layers(DeepseekV32IndexerCache)
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if not self.indexer_layer_names:
            self.draft_indexer_metadata_builder = None
        else:
            first_indexer = indexer_layers[self.indexer_layer_names[0]]
            builder_cls = first_indexer.get_attn_backend().get_builder_cls()
            self.draft_indexer_metadata_builder = builder_cls(
                first_indexer.get_kv_cache_spec(self.vllm_config),
                self.indexer_layer_names,
                self.vllm_config,
                self.device,
            )

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                probe_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(probe_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if not supports_multimodal(target_model):
            target_language_model = target_model
        else:
            # handle multimodality
            target_name = self.get_model_name(target_model)
            if target_name in (
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
            ):
                image_token_index = target_model.config.image_token_id
            elif target_name == "PixtralForConditionalGeneration":
                image_token_index = target_model.config.vision_config.image_token_id
            else:
                image_token_index = target_model.config.image_token_index
            self.model.config.image_token_index = image_token_index
            target_language_model = target_model.get_language_model()

        # share embed_tokens with the target model if needed
        if get_pp_group().world_size != 1:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )
        else:
            target_backbone = target_language_model.model
            if hasattr(target_backbone, "embed_tokens"):
                target_embed_tokens = target_backbone.embed_tokens
            elif hasattr(target_backbone, "embedding"):
                target_embed_tokens = target_backbone.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            if not hasattr(self.model, "has_own_embed_tokens"):
                # MTP model
                share_embeddings = True
                embed_msg = (
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )
            elif not self.model.has_own_embed_tokens:
                # EAGLE model without embed_tokens in its checkpoint
                share_embeddings = True
                embed_msg = (
                    "Detected EAGLE model without its own embed_tokens in the"
                    " checkpoint. Sharing target model embedding weights with the"
                    " draft model."
                )
            elif (
                isinstance(target_embed_tokens.weight, torch.Tensor)
                and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                # TODO: Offload to CPU for comparison to avoid extra GPU memory
                # usage in CI testing environments with limited GPU memory
                and torch.equal(
                    target_embed_tokens.weight.cpu(),
                    self.model.model.embed_tokens.weight.cpu(),
                )
            ):
                share_embeddings = True
                embed_msg = (
                    "Detected EAGLE model with embed_tokens identical to the target"
                    " model. Sharing target model embedding weights with the draft"
                    " model."
                )
            else:
                share_embeddings = False
                embed_msg = (
                    "Detected EAGLE model with distinct embed_tokens weights. "
                    "Keeping separate embedding weights from the target model."
                )
            logger.info(embed_msg)

            if share_embeddings:
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens

        # share lm_head with the target model if needed
        if not hasattr(self.model, "has_own_lm_head"):
            # MTP model
            share_lm_head = True
            lm_head_msg = (
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )
        elif not self.model.has_own_lm_head:
            # EAGLE model without lm_head in its checkpoint
            share_lm_head = True
            lm_head_msg = (
                "Detected EAGLE model without its own lm_head in the checkpoint. "
                "Sharing target model lm_head weights with the draft model."
            )
        elif (
            hasattr(target_language_model, "lm_head")
            and isinstance(target_language_model.lm_head.weight, torch.Tensor)
            and isinstance(self.model.lm_head.weight, torch.Tensor)
            # TODO: Offload to CPU for comparison to avoid extra GPU memory
            # usage in CI testing environments with limited GPU memory
            and torch.equal(
                target_language_model.lm_head.weight.cpu(),
                self.model.lm_head.weight.cpu(),
            )
        ):
            share_lm_head = True
            lm_head_msg = (
                "Detected EAGLE model with lm_head identical to the target model. "
                "Sharing target model lm_head weights with the draft model."
            )
        else:
            share_lm_head = False
            lm_head_msg = (
                "Detected EAGLE model with distinct lm_head weights. "
                "Keeping separate lm_head weights from the target model."
            )
        logger.info(lm_head_msg)

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
    ) -> None:
        """
        Run the draft model on dummy inputs; nothing is returned.

        Args:
            num_tokens: Unpadded number of tokens in the dummy batch.
            use_cudagraphs: If True, the cudagraph dispatcher chooses the
                runtime mode and padded token count. If False, the model runs
                with `CUDAGraphMode.NONE` on the DP-padded token count.
            is_graph_capturing: If True, run a single forward pass instead of
                `num_speculative_tokens` passes.
        """
        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        for fwd_idx in range(
            self.num_speculative_tokens if not is_graph_capturing else 1
        ):
            # Padding / dispatch is computed for the first two passes only and
            # reused for the remaining ones.
            # NOTE(review): presumably this matches the number of DP batch
            # coordinations in the real proposal loop — confirm.
            if fwd_idx <= 1:
                num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
                    num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
                )
                if use_cudagraphs:
                    cudagraph_runtime_mode, batch_desc = (
                        self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
                    )
                    num_input_tokens = batch_desc.num_tokens
                else:
                    # use_cudagraphs=False: skip the dispatcher and run eagerly
                    # on the DP-padded size.
                    cudagraph_runtime_mode = CUDAGraphMode.NONE
                    num_input_tokens = num_tokens_dp_padded
                if num_tokens_across_dp is not None:
                    num_tokens_across_dp[self.dp_rank] = num_input_tokens

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                kwargs = dict(
                    input_ids=input_ids,
                    positions=self._get_positions(num_input_tokens),
                    inputs_embeds=inputs_embeds,
                )
                if self.pass_hidden_states_to_model:
                    kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
                self.model(**kwargs)

    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Return the attention metadata builder used by the EAGLE layers.

        The lookup is keyed on the first EAGLE attention layer. Within each KV
        cache group, the first attention group containing that layer is asked
        for its builder, and the first non-None builder found is returned.

        Returns:
            The metadata builder for the EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
        """
        anchor_layer = self.attn_layer_names[0]
        builder = None
        for kv_cache_group in self.runner.attn_groups:
            owning_group = next(
                (grp for grp in kv_cache_group if anchor_layer in grp.layer_names),
                None,
            )
            if owning_group is None:
                continue
            builder = owning_group.get_metadata_builder()
            if builder is not None:
                break

        assert builder is not None, (
            "Failed to find attention metadata builder for EAGLE layers."
        )
        return builder

    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

1288
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Check that every drafting layer belongs to one and the same KVCacheGroup.

        All drafting layers must share a single AttentionMetadata, which
        requires them to be in the same group. May extend to multiple
        AttentionMetadata in the future.
        """
        # Layer name -> index of its KV cache group (later groups win on
        # duplicate names, as with sequential assignment).
        group_of_layer = {
            layer_name: group_idx
            for group_idx, group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in group.layer_names
        }
        drafting_groups = {group_of_layer[name] for name in self.attn_layer_names}
        assert len(drafting_groups) == 1, (
            "All drafting layers should belong to the same kv cache group"
        )

    def _pad_batch_across_dp(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor]:
        """
        Coordinate this rank's token count with the other data-parallel ranks
        via `coordinate_batch_across_dp`.

        DP padding is allowed only when cudagraphs are enabled, and
        microbatching is never requested.

        Returns:
            `(num_tokens_dp_padded, num_toks_across_dp)`. The second element is
            whatever `coordinate_batch_across_dp` returned and may be None; in
            that case the first element is `num_tokens_padded`, otherwise it is
            this rank's entry of the second element.
        """
        # TODO(Flechman): support DBO ubatching
        dp_padding_allowed = (
            self.cudagraph_dispatcher.cudagraph_mode != CUDAGraphMode.NONE
        )
        should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            allow_dp_padding=dp_padding_allowed,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        if num_toks_across_dp is None:
            return num_tokens_padded, num_toks_across_dp
        return int(num_toks_across_dp[self.dp_rank].item()), num_toks_across_dp


class EagleProposer(SpecDecodeBaseProposer):
    """EAGLE draft proposer.

    A thin specialization of ``SpecDecodeBaseProposer``. It forwards its
    constructor arguments to the base class and always sets
    ``pass_hidden_states_to_model=True``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        # The only setting this subclass pins is that hidden states are
        # passed to the draft model; everything else is forwarded unchanged.
        super().__init__(
            vllm_config,
            device,
            runner=runner,
            pass_hidden_states_to_model=True,
        )


# NOTE(woosuk): Currently, the code below is not used; we always use argmax
# to sample the draft tokens. We will use it once we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pick one draft token per row of ``logits``.

    Only the per-request temperature from ``sampling_metadata`` is honored.
    All other sampling parameters are ignored.

    Args:
        logits: Draft-model logits. ``temperature`` is viewed as ``(-1, 1)``
            and divided into it row-wise, so rows presumably correspond to
            requests. Modified in place unless every request is greedy.
        sampling_metadata: Supplies ``all_greedy``, ``all_random`` and the
            per-request ``temperature`` tensor.

    Returns:
        ``(next_token_ids, probs)``.
        - If every request is greedy, ``probs`` is ``logits`` itself
          (unnormalized).
        - Otherwise, ``probs`` is the float32 softmax of the
          temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # Greedy requests never read draft_probs during rejection sampling,
        # so the raw logits are returned in place of probabilities.
        return logits.argmax(dim=-1), logits

    assert sampling_metadata.temperature is not None
    temperature = sampling_metadata.temperature

    # Rows whose temperature is below _SAMPLING_EPS (the sampler's greedy
    # threshold) are treated as greedy. Their temperature is replaced by 1.0
    # so the division below never divides by ~0. The mask is only built when
    # greedy rows can exist.
    greedy_mask = None
    if not sampling_metadata.all_random:
        greedy_mask = temperature < _SAMPLING_EPS
        temperature = torch.where(greedy_mask, 1.0, temperature)

    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # Exponential-race sampling: argmax(p / q) with q ~ Exp(1) draws an index
    # from the categorical distribution p.
    # TODO(woosuk): Consider seeds.
    noise = torch.empty_like(probs).exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    sampled_ids = probs.div(noise).argmax(dim=-1).view(-1)

    if greedy_mask is None:
        return sampled_ids, probs
    # Greedy rows take their argmax instead of the random draw.
    return torch.where(greedy_mask, probs.argmax(dim=-1), sampled_ids), probs