eagle.py 64.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

王敏's avatar
王敏 committed
11
import vllm.envs as envs
12
13
14
15
16
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
17
from vllm.distributed.parallel_state import get_pp_group
18
from vllm.forward_context import set_forward_context
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
21
from vllm.model_executor.model_loader import get_model
22
from vllm.model_executor.models import supports_multimodal
23
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
24
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.platforms import current_platform
27
from vllm.triton_utils import triton
28
from vllm.utils.platform_utils import is_pin_memory_available
29
30
31
32
from vllm.v1.attention.backend import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
33
from vllm.v1.attention.backends.registry import AttentionBackendEnum
34
35
36
37
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
38
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
39
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
40
from vllm.v1.kv_cache_interface import KVCacheConfig
41
from vllm.v1.sample.metadata import SamplingMetadata
42
from vllm.v1.sample.sampler import _SAMPLING_EPS
43
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
44
45
46
47
from vllm.v1.spec_decode.utils import (
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
)
48
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
49
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
50
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
51

52
53
logger = init_logger(__name__)

# Sentinel slot index placed in slot mappings for tokens that must not write
# to the KV cache (padding slots and drafts past max_model_len).
PADDING_SLOT_ID: int = -1

56

57
class SpecDecodeBaseProposer:
58
59
60
61
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        pass_hidden_states_to_model: bool,
        runner=None,
    ):
        """Set up config, persistent device buffers and speculative-tree metadata.

        Args:
            vllm_config: Global vLLM config; must carry a speculative_config.
            device: Device on which the persistent input buffers are allocated.
            pass_hidden_states_to_model: Whether ``hidden_states`` is passed to
                the draft model alongside input ids / embeddings.
            runner: Owning model runner (``propose`` asserts it is set).
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None

        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method
        self.pass_hidden_states_to_model = pass_hidden_states_to_model

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

        # The drafter can get longer sequences than the target model.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
        )
        self.token_arange_np = np.arange(self.max_num_tokens)

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Populated lazily later (propose() falls back to
        # _get_attention_metadata_builder() while this is None).
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        self.compilation_config = self.vllm_config.compilation_config

        # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
        # Keys are initialized later via initialize_cudagraph_keys() called from
        # gpu_model_runner._check_and_update_cudagraph_mode after
        # adjust_cudagraph_sizes_for_spec_decode is called.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        # Use draft model's M-RoPE setting, not target model's
        # Draft models may be text-only even if target is multimodal
        self.uses_mrope = self.draft_model_config.uses_mrope
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        else:
            # Plain RoPE needs a 1-D (max_num_tokens,) buffer.
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Backing storage for the views handed out by _get_slot_mapping().
        self._slot_mapping_buffer = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

        # Determine allowed attention backends once during initialization.
        # NOTE: ``None`` disables the attention-metadata type check that
        # propose() performs for num_speculative_tokens > 1. The upstream ROCm
        # allow-list (Triton / ROCm / AITER-FA / MLA / FlexAttention metadata)
        # is commented out in this tree; restore it from version control if the
        # check is needed again.
        self.allowed_attn_types: tuple | None = None

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
        ).repeat(max_batch_size, 1)
221

222
223
224
225
226
227
228
229
230
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
231
232
233
234
235
            # Convert M-RoPE positions if target model uses M-RoPE
            # but draft doesn't, For text inputs, all M-RoPE
            # dimensions are identical
            if self.vllm_config.model_config.uses_mrope:
                positions = positions[0]
236
237
            self.positions[:num_tokens] = positions

238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Return slot_mapping dict for EAGLE layers.

        If slot_mapping is provided, copies it into the buffer first.
        """
        if slot_mapping is not None:
            num_actual = slot_mapping.shape[0]
            self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
            if num_tokens > num_actual:
                self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

        view = self._slot_mapping_buffer[:num_tokens]
        return {name: view for name in self.attn_layer_names + self.indexer_layer_names}

256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Register the cudagraph dispatcher keys used by the drafter.

        The drafter only ever uses PIECEWISE graphs. It selects PIECEWISE when
        speculative decoding is not forced eager and the mixed mode of
        ``cudagraph_mode`` is PIECEWISE or FULL; otherwise it selects NONE.
        Call this after ``adjust_cudagraph_sizes_for_spec_decode``.
        """
        eagle_cudagraph_mode = CUDAGraphMode.NONE
        if not self.speculative_config.enforce_eager:
            if cudagraph_mode.mixed_mode() in (
                CUDAGraphMode.PIECEWISE,
                CUDAGraphMode.FULL,
            ):
                eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE

        self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

273
274
275
276
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Draft speculative tokens for every request in the batch.

        The first step runs the draft model over all target tokens. The
        remaining ``num_speculative_tokens - 1`` steps each feed back one
        greedy (argmax) token per request.

        Args:
            mm_embed_inputs: ``(mm_embeds, is_mm_embed)``, forwarded to
                ``embed_input_ids`` when multimodal inputs are supported.
            num_rejected_tokens_gpu: When given and
                ``num_speculative_tokens > 1``, subtracted from ``seq_lens``
                before the follow-up drafting steps.
            sampling_metadata: Not used by this method.
            slot_mappings: Only forwarded to ``propose_tree``.

        Returns:
            Draft token ids shaped ``[batch_size, num_speculative_tokens]``,
            or ``[batch_size, num_tree_tokens]`` on the tree-attention path.
            On the non-tree paths, when ``envs.VLLM_REJECT_SAMPLE_OPT`` is set,
            a tuple ``(draft_token_ids, draft_probs)`` is returned instead.
            ``draft_probs`` is float32 softmax output shaped
            ``[batch_size, num_speculative_tokens, logits.shape[-1]]``.
        """
        # Read the flag once so every branch below agrees on the return type.
        return_draft_probs = envs.VLLM_REJECT_SAMPLE_OPT
        batch_size = common_attn_metadata.batch_size()

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size

        num_tokens, last_token_indices, common_attn_metadata = (
            self.set_inputs_first_pass(
                target_token_ids=target_token_ids,
                next_token_ids=next_token_ids,
                target_positions=target_positions,
                last_token_indices=last_token_indices,
                cad=common_attn_metadata,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
            )
        )

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
        )

        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens_dp_padded
        )
        num_input_tokens = batch_desc.num_tokens
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        if self.pass_hidden_states_to_model:
            # target_hidden_states and self.hidden_states can have different
            # hidden dims. E.g. large target model and small draft model.
            self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        model_kwargs = {
            "input_ids": input_ids,
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": inputs_embeds,
        }
        if self.pass_hidden_states_to_model:
            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=self._get_slot_mapping(
                num_input_tokens, common_attn_metadata.slot_mapping
            ),
        ):
            ret_hidden_states = self.model(**model_kwargs)
            if not self.model_returns_tuple():
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states

        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            if return_draft_probs:
                draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
                return (
                    draft_token_ids.view(-1, 1),
                    draft_prob.view(-1, 1, logits.shape[-1]),
                )
            return draft_token_ids.view(-1, 1)

        if self.uses_mrope:
            positions = self.mrope_positions[:, last_token_indices]
        else:
            positions = self.positions[last_token_indices]
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            # NOTE(review): this path returns token ids only, even when
            # VLLM_REJECT_SAMPLE_OPT is set; confirm callers handle both forms.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
                slot_mappings=slot_mappings,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
        draft_probs_list: list[torch.Tensor] = []
        if return_draft_probs:
            draft_probs_list.append(logits.softmax(dim=-1, dtype=torch.float32))

        batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
        )

        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            batch_size_dp_padded
        )
        input_batch_size = batch_desc.num_tokens
        if batch_size_across_dp is not None:
            batch_size_across_dp[self.dp_rank] = input_batch_size

        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # NOTE(review): the seq_lens updates below are in-place; confirm
            # seq_lens does not alias a buffer the main model reuses.
            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
            # Increment the maximum sequence length. We increment max_seq_len
            # unconditionally even though some seq_lens may have been capped above,
            # as max_seq_len serves as an upper bound for sequence lengths.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )

            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed when common_attn_metadata.seq_lens_cpu is deprecated.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu += 1
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu += 1

            # Compute the slot mapping.
            block_size = attn_metadata_builder.kv_cache_spec.block_size
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // block_size
            else:
                block_numbers = clamped_positions // block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions[0] % block_size
                )
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions % block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            if self.draft_indexer_metadata_builder:
                draft_indexer_metadata = (
                    self.draft_indexer_metadata_builder.build_for_drafting(
                        common_attn_metadata=common_attn_metadata,
                        draft_index=token_index + 1,
                    )
                )
            else:
                draft_indexer_metadata = None

            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata
            for layer_name in self.indexer_layer_names:
                assert draft_indexer_metadata is not None
                per_layer_attn_metadata[layer_name] = draft_indexer_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            model_kwargs = {
                "input_ids": input_ids,
                "positions": self._get_positions(input_batch_size),
                "inputs_embeds": inputs_embeds,
            }
            if self.pass_hidden_states_to_model:
                model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]

            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    input_batch_size, common_attn_metadata.slot_mapping
                ),
            ):
                ret_hidden_states = self.model(**model_kwargs)
                if not self.model_returns_tuple():
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states

            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)
            if return_draft_probs:
                draft_probs_list.append(
                    logits.softmax(dim=-1, dtype=torch.float32)
                )

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        if return_draft_probs:
            # [batch_size, num_speculative_tokens, logits.shape[-1]]
            draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
            return draft_token_ids, draft_probs
        return draft_token_ids
638

639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Fill the draft model's input buffers for the first drafting pass.

        The draft input ids are the target token ids shifted left by one
        position. The slot at each request's last token is then overwritten
        with that request's freshly sampled next token. Positions are copied
        into the persistent buffer through ``self._set_positions``.

        Args:
            target_token_ids: Flattened token ids that were fed to the target.
            next_token_ids: The sampled next token for each request.
            target_positions: Positions matching ``target_token_ids``.
            last_token_indices: Index of each request's last token. When
                ``None`` it is derived as ``cad.query_start_loc[1:] - 1``.
            cad: Common attention metadata; returned unchanged.
            num_rejected_tokens_gpu: Not referenced by this implementation.

        Returns:
            ``(num_tokens, last_token_indices, cad)``.
        """
        token_count = target_token_ids.shape[0]
        if last_token_indices is None:
            # query_start_loc[i + 1] - 1 is the final token of request i.
            last_token_indices = cad.query_start_loc[1:] - 1

        # Shift left by one: [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: token_count - 1] = target_token_ids[1:]
        # Overwrite each request's tail with its sampled next token:
        # [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # Positions go through the persistent buffer used for cudagraphs.
        self._set_positions(token_count, target_positions)
        return token_count, last_token_indices, cad

    def model_returns_tuple(self) -> bool:
        """Return whether the draft model's forward output is a tuple.

        The ``"mtp"`` and ``"draft_model"`` methods are treated as returning
        a single tensor. Every other method is treated as returning a tuple.
        """
        single_output_methods = ("mtp", "draft_model")
        return self.method not in single_output_methods

667
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        Build each request's next token id from the CPU-side samples.

        Normally the last sampled token of each request is used. A request
        with no sampled tokens (e.g. a partial prefill) instead reads the
        token at index ``num_computed_tokens + num_scheduled_tokens`` from
        its request state.

        Returns:
            An int32 tensor with one entry per row of ``sampled_token_ids``,
            placed on the same device as ``self.input_ids``.
        """
        batch_req_ids = gpu_input_batch.req_ids

        def _next_for(row: int, sampled: list[int]) -> int:
            if sampled:
                # Common case: the most recently sampled token.
                return sampled[-1]
            # Rare case (partial prefill): read it back from request state.
            req_id = batch_req_ids[row]
            state = requests[req_id]
            seq_len = state.num_computed_tokens + num_scheduled_tokens[req_id]
            return state.get_token_id(seq_len)

        ids = [
            _next_for(row, sampled)
            for row, sampled in enumerate(sampled_token_ids)
        ]
        return torch.tensor(ids, dtype=torch.int32, device=self.input_ids.device)

700
701
702
703
704
705
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute next-token ids and valid sampled-token counts on the GPU.

        A Triton kernel processes one row of ``sampled_token_ids`` per request.
        It writes that request's next token id and its number of valid sampled
        tokens. Requests flagged in ``discard_request_mask`` have no sampled
        next token; they use a "backup" token id from
        ``request.get_token_id(seq_len)``, which is precomputed here on the
        CPU. The vocab size is passed so the kernel can tell valid ids from
        rejected slots (presumably marked with an out-of-vocab sentinel;
        confirm against the kernel).

        Returns:
            ``(next_token_ids, valid_sampled_tokens_count)``. Both are int32
            tensors of length ``batch_size`` on ``sampled_token_ids.device``.
        """
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        # Read the CPU sequence lengths once as Python ints. The previous code
        # accessed ``seq_lens_cpu`` and called ``.item()`` for every request.
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
        # Backup next token per request: the token at index ``seq_len`` of the
        # request state.
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [requests[req_ids[i]].get_token_id(seq_lens[i]) for i in range(num_reqs)],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
        valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

        # Kernel grid: one program per request (row).
        grid = (batch_size,)
        # The token block is the next power of two >= num_tokens.
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count

758
759
760
761
762
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        Prepare speculator inputs without compacting away rejected tokens.

        Every token stays an input to the speculator. Rejected tokens act as
        padding and are later skipped through ``token_indices_to_sample``.
        Query and sequence lengths therefore carry over unchanged from
        ``common_attn_metadata``. A Triton kernel fills the per-request
        sample indices and rejected-token counts on the device.

        No blocking CPU operations should be introduced in this function.

        Returns:
            ``(spec_common_attn_metadata, token_indices_to_sample,
            num_rejected_tokens_gpu)``. The last two are int32 tensors with
            one entry per request.
        """
        cm = common_attn_metadata
        num_reqs = cm.num_reqs
        device = valid_sampled_tokens_count.device

        def _per_request_int32() -> torch.Tensor:
            return torch.empty((num_reqs,), dtype=torch.int32, device=device)

        token_indices_to_sample = _per_request_int32()
        num_rejected_tokens_gpu = _per_request_int32()

        # One Triton program per request.
        eagle_prepare_inputs_padded_kernel[(num_reqs,)](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            cm.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )

        # Query lengths are unchanged because rejected tokens remain as padding.
        qsl_cpu = cm.query_start_loc_cpu
        query_lens = qsl_cpu[1:] - qsl_cpu[:-1]
        total_num_tokens = qsl_cpu[-1].item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=cm.query_start_loc,
            seq_lens=cm.seq_lens,
            query_start_loc_cpu=qsl_cpu,
            _seq_lens_cpu=cm._seq_lens_cpu,
            _num_computed_tokens_cpu=cm._num_computed_tokens_cpu,
            num_reqs=cm.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=query_lens.max().item(),
            max_seq_len=cm.seq_lens_cpu.max().item(),
            block_table_tensor=cm.block_table_tensor,
            slot_mapping=cm.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=cm.dcp_local_seq_lens,
        )

        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )
818

819
820
821
822
823
824
825
826
827
828
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> list[torch.Tensor]:
        """Draft tokens level by level over the speculation tree.

        The root level is sampled directly from ``logits``. Each further level
        appends the previous level's drafts to the running tree inputs and
        runs the draft model once on the whole tree, using tree-attention
        metadata rebuilt for that level. It then samples the children of the
        newest drafts (argmax for one child, top-k for several).

        Args:
            batch_size: Number of requests.
            logits: Draft logits for the root level.
            positions: Position of each request's root token.
            hidden_states: Draft hidden states matching ``logits``.
            common_attn_metadata: Root metadata; a per-level copy is derived
                from it with ``dataclasses.replace``.
            slot_mappings: Not referenced by this method. Slot mappings are
                computed per level from the block table.

        Returns:
            One tensor per tree level, each viewed as ``[batch_size, -1]``,
            holding the draft token ids sampled at that level.
        """
        builder = self.runner.attn_groups[0][0].get_metadata_builder()
        assert isinstance(builder, TreeAttentionMetadataBuilder)

        def _sample_children(level_logits: torch.Tensor, k: int) -> torch.Tensor:
            # A single child is the argmax; several children are the top-k.
            if k == 1:
                picked = level_logits.argmax(dim=-1)
            else:
                picked = torch.topk(level_logits, k, dim=-1).indices
            return picked.view(batch_size, -1)

        cu_drafts = self.cu_drafts_per_level
        total_num_drafts = cu_drafts[0]
        level_num_drafts = total_num_drafts

        # Root level: sample a draft token for each child.
        num_children = self.child_drafts_per_level[0]
        draft_token_ids = _sample_children(logits, num_children)
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Start empty; each level's outputs are concatenated onto these.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )

        # Absolute position of every draft slot in the tree, per request.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )

        for level in range(len(cu_drafts) - 1):
            next_level = level + 1

            # Positions past max_model_len are zeroed to keep RoPE in range.
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                positions + next_level,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # One position entry per draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )
            if num_children > 1:
                # Each parent hidden state feeds all of its children.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Rebuild tree-attention metadata: every draft produced so far is
            # a query token at this level.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=next_level
            )
            # Every draft attention layer shares the same metadata object.
            per_layer_attn_metadata = dict.fromkeys(
                self.attn_layer_names, attn_metadata
            )

            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # Requests past max_model_len get seq_len 1 to minimise their cost.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Slot mapping from the block table. Overflowing requests get
            # PADDING_SLOT_ID so padding never writes into the KV cache.
            block_size = builder.kv_cache_spec.block_size
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_ids = attn_metadata.block_table.gather(
                dim=1, index=query_positions // block_size
            )
            slot_mapping = block_ids * block_size + query_positions % block_size
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Stage the tree inputs in the persistent (cudagraph) buffers.
            num_tokens = attn_metadata.num_actual_tokens
            self.input_ids[:num_tokens] = tree_input_ids.view(-1)
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens
            )
            num_input_tokens = batch_desc.num_tokens
            forward_slot_mapping = self._get_slot_mapping(
                num_input_tokens, attn_metadata.slot_mapping
            )

            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=forward_slot_mapping,
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Keep only the outputs for the drafts added at this level.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample the children of this level's drafts.
            num_children = self.child_drafts_per_level[next_level]
            draft_token_ids = _sample_children(logits, num_children)
            draft_token_ids_list.append(draft_token_ids)

            # Advance the per-level and cumulative draft counters.
            level_num_drafts = cu_drafts[next_level] - total_num_drafts
            total_num_drafts = cu_drafts[next_level]
        return draft_token_ids_list

992
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        Prepare speculator inputs after dropping rejected draft tokens.

        Updates ``common_attn_metadata`` so that each request's query length
        and sequence length exclude its rejected tokens. Also returns the
        indices, into the original flattened token layout, of the tokens that
        should be fed to the speculator.

        Args:
            common_attn_metadata: Metadata of the target forward pass.
            sampled_token_ids: Per-request accepted tokens plus one sampled
                token.
            num_draft_tokens: Number of draft tokens proposed per request.

        Returns:
            ``(spec_common_attn_metadata, token_indices)``. ``token_indices``
            is copied (non-blocking) to the device of
            ``common_attn_metadata.query_start_loc``.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # A request with n > 0 drafts samples (accepted + 1) tokens, so it
        # rejected n + 1 - len(sampled) drafts; requests with no drafts reject 0.
        num_rejected_tokens = torch.tensor(
            [
                n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ],
            dtype=torch.int32,
        )

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Plain Python int rather than a numpy scalar for downstream metadata.
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            _seq_lens_cpu=new_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # Must agree with the post-rejection query_start_loc above; the
            # pre-rejection lengths over-state the longest query.
            max_query_len=new_num_tokens_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
1093

1094
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of ``model``.

        If ``model`` exposes a ``module`` attribute (a multi-GPU wrapper), the
        wrapped object's class name is returned instead.
        """
        inner = getattr(model, "module", model)  # multi-GPU wrapper
        return inner.__class__.__name__

1099
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and share weights with the target where possible.

        Steps:
          1. Load the draft model under the ``eagle_head`` model tag. The
             draft's own attention and indexer layers are identified as the
             layers registered after loading, minus those the target already
             registered.
          2. If the draft has ``DeepseekV32IndexerCache`` layers, build their
             metadata builder from the first such layer.
          3. Probe multimodal support of the draft model. If the target is
             multimodal, copy its image token id onto the draft config.
          4. With a single pipeline-parallel rank, share ``embed_tokens`` when
             the draft lacks its own, holds identical weights, or is an MTP
             model.
          5. Share ``lm_head`` under the same rules. Also point every MTP
             ``shared_head.head`` at the target ``lm_head``.

        Args:
            target_model: The already-loaded target model.
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        # Snapshot the target's layers so the draft's can be found by difference.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        # Indexer layers are handled separately from regular attention layers.
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if self.indexer_layer_names:
            first_layer = self.indexer_layer_names[0]
            self.draft_indexer_metadata_builder = (
                indexer_layers[first_layer]
                .get_attn_backend()
                .get_builder_cls()(
                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                    self.indexer_layer_names,
                    self.vllm_config,
                    self.device,
                )
            )
        else:
            self.draft_indexer_metadata_builder = None

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models. Probe with a dummy embedding call.
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            if self.get_model_name(target_model) in [
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
                "Qwen3VLMoeForConditionalGeneration",
                "GlmOcrForConditionalGeneration",
                "Qwen3_5ForConditionalGeneration",
                "Qwen3_5MoeForConditionalGeneration",
            ]:
                self.model.config.image_token_index = target_model.config.image_token_id
            elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
                self.model.config.image_token_index = (
                    target_model.config.vision_config.image_token_id
                )
            else:
                self.model.config.image_token_index = (
                    target_model.config.image_token_index
                )
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, "embed_tokens"):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, "embedding"):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            share_embeddings = False
            if hasattr(self.model, "has_own_embed_tokens"):
                # EAGLE model
                if not self.model.has_own_embed_tokens:
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model without its own embed_tokens in the"
                        " checkpoint. Sharing target model embedding weights with the"
                        " draft model."
                    )
                elif (
                    isinstance(target_embed_tokens.weight, torch.Tensor)
                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                    # TODO: Offload to CPU for comparison to avoid extra GPU memory
                    # usage in CI testing environments with limited GPU memory
                    and torch.equal(
                        target_embed_tokens.weight.cpu(),
                        self.model.model.embed_tokens.weight.cpu(),
                    )
                ):
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model with embed_tokens identical to the target"
                        " model. Sharing target model embedding weights with the draft"
                        " model."
                    )
                else:
                    logger.info(
                        "Detected EAGLE model with distinct embed_tokens weights. "
                        "Keeping separate embedding weights from the target model."
                    )
            else:
                # MTP model
                share_embeddings = True
                logger.info(
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )

            if share_embeddings:
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
        else:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )

        # share lm_head with the target model if needed
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                # TODO: Offload to CPU for comparison to avoid extra GPU memory
                # usage in CI testing environments with limited GPU memory
                and torch.equal(
                    target_language_model.lm_head.weight.cpu(),
                    self.model.lm_head.weight.cpu(),
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head

            # MTP models call compute_logits via shared_head.head (a
            # ParallelLMHead inside each MTP layer), not self.model.lm_head.
            # If the checkpoint omits a copy of the lm_head weights at the
            # MTP layer path, shared_head.head stays uninitialised and
            # produces NaN logits. Whenever lm_head is shared, share it there
            # explicitly too.
            inner = getattr(self.model, "model", None)
            # getattr(None, ...) falls back to the default, so no truthiness
            # check on ``inner`` is needed. Container modules (ModuleList,
            # ModuleDict, ...) define __len__ and would be falsy when empty.
            layers = getattr(inner, "layers", None)
            if layers is not None:
                items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
                for layer in items:
                    sh = getattr(layer, "shared_head", None)
                    if sh is not None and hasattr(sh, "head"):
                        del sh.head
                        sh.head = target_language_model.lm_head
                        logger.info(
                            "Shared target model lm_head with MTP shared_head.head."
                        )

1302
1303
1304
1305
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run the draft model forward on placeholder inputs.

        Inputs are slices of the proposer's persistent buffers
        (``self.input_ids`` or ``self.inputs_embeds``, plus
        ``self.hidden_states`` when ``self.pass_hidden_states_to_model``),
        cut to the padded token count. The model output is discarded.

        Args:
            num_tokens: Token count per forward pass before DP / CUDA graph
                padding.
            use_cudagraphs: If True, the DP-padded size is dispatched through
                ``self.cudagraph_dispatcher`` to choose the CUDA graph runtime
                mode and the padded batch size; otherwise
                ``CUDAGraphMode.NONE`` is used with the DP-padded size.
            is_graph_capturing: If True, run a single forward pass instead of
                ``self.num_speculative_tokens`` passes.
            slot_mappings: Optional per-layer slot mappings for the forward
                context. If it contains the first EAGLE attention layer,
                EAGLE's own slot-mapping buffer (``self._get_slot_mapping``)
                is used instead.
        """
        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        for fwd_idx in range(
            self.num_speculative_tokens if not is_graph_capturing else 1
        ):
            # Batch sizing is only computed on the first two passes; later
            # passes reuse it (num_tokens is identical for every pass here).
            if fwd_idx <= 1:
                num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
                    num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
                )
                if use_cudagraphs:
                    cudagraph_runtime_mode, batch_desc = (
                        self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
                    )
                    num_input_tokens = batch_desc.num_tokens
                else:
                    cudagraph_runtime_mode = CUDAGraphMode.NONE
                    num_input_tokens = num_tokens_dp_padded
                # Record this rank's final (possibly graph-padded) size in the
                # cross-DP token-count tensor.
                if num_tokens_across_dp is not None:
                    num_tokens_across_dp[self.dp_rank] = num_input_tokens

            # Make sure to use EAGLE's own buffer during cudagraph capture.
            if (
                self.attn_layer_names
                and slot_mappings is not None
                and self.attn_layer_names[0] in slot_mappings
            ):
                slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
            else:
                slot_mapping_dict = slot_mappings or {}

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                kwargs = dict(
                    input_ids=input_ids,
                    positions=self._get_positions(num_input_tokens),
                    inputs_embeds=inputs_embeds,
                )
                if self.pass_hidden_states_to_model:
                    kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
                self.model(**kwargs)
1363

1364
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        The builder is looked up by the first EAGLE attention layer name
        (``self.attn_layer_names[0]``) among ``self.runner.attn_groups``.
        ``validate_same_kv_cache_group`` requires all drafting layers to share
        one KV cache group, so a single builder is used for all of them.

        Returns:
            The metadata builder of the attention group that contains the
            first EAGLE attention layer.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
        """
        builder = None
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    break
            if builder is not None:
                break

        # Raise explicitly instead of using `assert` so the documented failure
        # still happens when Python runs with -O.
        if builder is None:
            raise AssertionError(
                "Failed to find attention metadata builder for EAGLE layers."
            )
        return builder

1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

1405
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """Validate that all drafting layers belong to the same KVCacheGroup.

        Need this assumption to ensure all drafting layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Args:
            kv_cache_config: Config whose ``kv_cache_groups`` are scanned to
                map every layer name to the index of its group.

        Raises:
            AssertionError: If the EAGLE attention layers
                (``self.attn_layer_names``) do not map to exactly one group.
            KeyError: If an EAGLE attention layer is not listed in any group.
        """
        groups = kv_cache_config.kv_cache_groups
        # Later groups win on duplicate layer names, as in a plain loop.
        layer_to_group_idx = {
            layer_name: group_idx
            for group_idx, group in enumerate(groups)
            for layer_name in group.layer_names
        }
        drafting_group_ids = {
            layer_to_group_idx[layer_name] for layer_name in self.attn_layer_names
        }
        # Explicit raise so the check is not stripped under `python -O`.
        if len(drafting_group_ids) != 1:
            raise AssertionError(
                "All drafting layers should belong to the same kv cache group"
            )
1427

Rémi Delacourt's avatar
Rémi Delacourt committed
1428
1429
1430
1431
1432
1433
    def _pad_batch_across_dp(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor | None]:
        """Coordinate this rank's draft batch size with the other DP ranks.

        Args:
            num_tokens_unpadded: Number of real tokens on this rank.
            num_tokens_padded: Token count after local padding.

        Returns:
            ``(num_tokens_dp_padded, num_toks_across_dp)``.
            ``num_tokens_dp_padded`` is this rank's entry of
            ``num_toks_across_dp``, or ``num_tokens_padded`` when that tensor
            is ``None``. ``num_toks_across_dp`` is the per-rank token-count
            tensor from ``coordinate_batch_across_dp`` and may be ``None``.
        """
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            # DP padding is only permitted when CUDA graphs are enabled.
            allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
            != CUDAGraphMode.NONE,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        num_tokens_dp_padded = num_tokens_padded
        if num_toks_across_dp is not None:
            num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
        return num_tokens_dp_padded, num_toks_across_dp
1450
1451


1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
class EagleProposer(SpecDecodeBaseProposer):
    """Speculative-decoding proposer for EAGLE-style draft heads.

    Adds no behavior of its own beyond ``SpecDecodeBaseProposer``. It only
    fixes ``pass_hidden_states_to_model=True`` when constructing the base.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        # Have the base class pass ``hidden_states`` into every draft model
        # forward call.
        super().__init__(
            vllm_config,
            device,
            runner=runner,
            pass_hidden_states_to_model=True,
        )


1467
1468
1469
1470
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample draft token ids and return them with their probabilities.

    Args:
        logits: Draft-model logits, presumably ``(num_tokens, vocab_size)``
            (TODO confirm). Unless every request is greedy, this tensor is
            divided by the temperature IN PLACE.
        sampling_metadata: Only ``all_greedy``, ``all_random`` and
            ``temperature`` are read.

    Returns:
        ``(next_token_ids, probs)``. If every request is greedy, ``probs`` is
        the unmodified ``logits`` tensor, not a distribution, because the
        rejection sampler does not use draft probs for greedy requests.
        Otherwise it is the float32 softmax of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    assert sampling_metadata.temperature is not None

    # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
    # consistent with sampler.py's _SAMPLING_EPS threshold
    temperature = sampling_metadata.temperature
    # Avoid division by zero if there are greedy requests.
    if not sampling_metadata.all_random:
        is_greedy = temperature < _SAMPLING_EPS
        temperature = torch.where(is_greedy, 1.0, temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    # argmax(probs / q) with q ~ Exp(1) draws a sample from `probs`.
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)

    # Greedy rows (only possible when not all_random) take the argmax token.
    if not sampling_metadata.all_random:
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids)
    return next_token_ids, probs