eagle.py 50 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
from vllm.config import (
12
    CompilationMode,
13
14
15
16
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
17
from vllm.distributed.parallel_state import get_pp_group
18
from vllm.forward_context import set_forward_context
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
21
from vllm.model_executor.model_loader import get_model
22
from vllm.model_executor.models import supports_multimodal
23
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
24
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.platforms import current_platform
27
from vllm.utils.platform_utils import is_pin_memory_available
28
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
29
30
31
32
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
33
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
34
35
36
37
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
38
from vllm.v1.kv_cache_interface import KVCacheConfig
39
from vllm.v1.sample.metadata import SamplingMetadata
40
from vllm.v1.sample.sampler import _SAMPLING_EPS
41
42
43
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
44

45
46
logger = init_logger(__name__)

# Sentinel written into ``slot_mapping`` for draft tokens whose KV must not be
# stored in the cache (``EagleProposer.propose`` uses it for draft positions
# that run past ``max_model_len``).
PADDING_SLOT_ID = -1


class EagleProposer:
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Set up the EAGLE/MTP drafter and its persistent input buffers.

        Args:
            vllm_config: Engine-wide configuration; its ``speculative_config``
                must be set (asserted below).
            device: Device on which the persistent buffers are allocated.
            runner: The owning model runner. ``propose`` asserts it is set and
                ``propose_tree`` reads its attention groups.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Builders and layer names are filled in outside this constructor.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.use_cuda_graph = False

        compilation_config = self.vllm_config.compilation_config
        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            cudagraph_mode = compilation_config.cudagraph_mode
            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                CUDAGraphMode.PIECEWISE
            ):
                logger.warning(
                    "Currently the eagle proposer only supports cudagraph_mode "
                    "PIECEWISE, if you want the drafter to use cuda graphs, "
                    "please set compilation_config.cudagraph_mode to PIECEWISE "
                    "or FULL_AND_PIECEWISE"
                )
            self.use_cuda_graph = (
                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
                and not self.speculative_config.enforce_eager
            )

        self.cudagraph_batch_sizes = (
            (sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes))
            if self.use_cuda_graph
            else []
        )

        # CUDA graphs are only usable if at least one capture size exists.
        self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes)

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        # Only one of ``mrope_positions`` / ``positions`` is allocated;
        # ``_get_positions`` / ``_set_positions`` pick the right one.
        if self.uses_mrope:
            # M-RoPE need (3, max_num_tokens)
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens), dtype=torch.int64, device=device
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )
        # CPU twin of ``self.arange``. ``propose`` slices it with
        # ``[: batch_size + 1]`` to build ``query_start_loc_cpu``, so it must be
        # sized like ``self.arange``. Sizing it to ``max_num_tokens`` alone is
        # one element short when ``max_num_seqs >= max_num_batched_tokens``.
        self.token_arange_np = np.arange(max_num_slots_for_arange)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # Fallback next-token ids, one per request, staged on CPU and copied
        # to the device by ``prepare_next_token_ids_padded``.
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Determine allowed attention backends once during initialization.
        from vllm.attention.backends.registry import AttentionBackendEnum

        # ``None`` means "no restriction"; only ROCm narrows the set.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
            # ROCM_AITER_FA is an optional backend
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)
            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # NOTE(review): assumes tree_choices is ordered by depth so the last
        # node is the deepest one — confirm against the config validation.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            # Integer division: assumes a uniform branching factor per level.
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1,
            len(self.tree_choices) + 1,
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)

    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
            self.positions[:num_tokens] = positions

205
206
207
208
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run the draft model and return the proposed draft token ids.

        The first draft step consumes the target model's tokens (shifted by
        one, with each request's last slot replaced by ``next_token_ids``) and
        hidden states. Later steps feed back the previous step's argmax token.

        Args:
            target_token_ids: Token ids processed by the target model.
            target_positions: Positions of those tokens.
            target_hidden_states: Target hidden states for those tokens.
            next_token_ids: One next token per request.
            last_token_indices: Index of each request's last token in the
                flattened target tokens; defaults to
                ``query_start_loc[1:] - 1``.
            common_attn_metadata: Attention metadata for the target tokens.
                For the multi-step (non-tree) path it is MUTATED in place
                (query layout, ``seq_lens``, ``max_seq_len``, slot mapping).
            sampling_metadata: Not read by this method.
            mm_embed_inputs: Optional ``(mm_embeds, is_mm_embed)`` forwarded
                to ``embed_input_ids`` when the model takes multimodal input.

        Returns:
            ``[batch_size, num_speculative_tokens]`` draft ids, or, when tree
            attention is in use, the concatenated per-level tree drafts.

        Raises:
            ValueError: if more than one draft token is requested and the
                attention metadata type is not in ``self.allowed_attn_types``.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        cudagraph_runtime_mode = CUDAGraphMode.NONE
        if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        # Advanced indexing copies, so the in-place updates below do not
        # touch ``target_positions``.
        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            input_batch_size = batch_size
            cudagraph_runtime_mode = CUDAGraphMode.NONE

        # From here on every request contributes exactly one query token.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1
            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length.
            # Since it is complex to remove such requests from the batch,
            # we keep them in the batch but adjust the position ids
            # and slot mappings to avoid the
            # out-of-range access during the model execution.
            # The draft tokens generated with this adjustment
            # should be ignored.
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            if self.uses_mrope:
                exceeds_max_model_len = positions[0] >= self.max_model_len
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
                # all dimensions of positions are the same
                slot_positions = clamped_positions[0]
            else:
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
                slot_positions = clamped_positions

            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            common_attn_metadata.seq_lens_cpu += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
            # Keep max_seq_len an upper bound of the incremented seq_lens so the
            # metadata rebuilt below does not report a stale maximum. It never
            # needs to exceed max_model_len.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )

            common_attn_metadata.num_computed_tokens_cpu = (
                common_attn_metadata.seq_lens_cpu - 1
            )

            # Compute the slot mapping.
            block_numbers = slot_positions // self.block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            ).view(-1)
            common_attn_metadata.slot_mapping = (
                block_ids * self.block_size + slot_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(input_batch_size),
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            # Drop the CUDA-graph padding rows before sampling.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    def prepare_next_token_ids_cpu(
486
        self,
Cyrus Leung's avatar
Cyrus Leung committed
487
        sampled_token_ids: list[np.ndarray],
488
489
490
491
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
492
493
494
495
496
497
498
499
500
501
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
Cyrus Leung's avatar
Cyrus Leung committed
502
            if token_ids.shape[0] > 0:
503
504
505
506
507
508
509
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
510
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
511
512
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
Cyrus Leung's avatar
Cyrus Leung committed
513
        return torch.tensor(
514
515
            next_token_ids, dtype=torch.int32, device=self.input_ids.device
        )
516

517
518
519
520
521
522
523
524
525
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_indices: torch.Tensor,
        num_discarded_requests: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
526
527
528
529
530
531
532
533
534
535
536
537
538
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
539
540
541
542
543
544
545
546
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [
                requests[gpu_input_batch.req_ids[i]].get_token_id(
                    common_attn_metadata.seq_lens_cpu[i].item()
                )
                for i in range(num_reqs)
            ]
        )
547
548
549
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled tokens indices that should not be sampled.
550
551
552
        discard_sampled_tokens_req_indices = discard_request_indices[
            :num_discarded_requests
        ]
553
554
555

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
556
557
            0, discard_sampled_tokens_req_indices, -1
        )
558
559

        # Generate a mask for all valid tokens within those requests
560
561
562
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )
563
564
565
566
567
568
569
570
571
572
573

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
574
575
            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
        ).squeeze(1)
576
577
578
579

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
580
581
582
583
            last_valid_indices != -1,
            selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size],
        )
584
585
586

        return next_token_ids, valid_sampled_tokens_count

587
588
589
590
591
592
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """Build drafter inputs that keep every scheduled token, rejected ones included.

        Rejected tokens are not removed; they stay in the batch as padding and
        are skipped later through ``token_indices_to_sample``. Apart from the
        ``.item()`` reads of CPU tensors used to size the metadata, no blocking
        CPU operations are performed here.

        Returns:
            ``(spec_common_attn_metadata, token_indices, token_indices_to_sample)``
            where ``token_indices`` enumerates every scheduled token and
            ``token_indices_to_sample`` points at the last accepted token of
            each request.
        """
        cu_drafts = spec_decode_metadata.cu_num_draft_tokens
        # Per-request draft counts recovered from the cumulative sum.
        drafts_per_req = torch.cat([cu_drafts[0:1], cu_drafts[1:] - cu_drafts[:-1]])

        # Requests without drafts reject nothing. Otherwise every scheduled
        # slot (drafts + 1) that did not yield a valid sample counts as rejected.
        rejected_per_req = torch.where(
            drafts_per_req > 0,
            drafts_per_req + 1 - valid_sampled_tokens_count,
            torch.zeros_like(drafts_per_req),
        )

        qsl_cpu = common_attn_metadata.query_start_loc_cpu
        query_lens = qsl_cpu[1:] - qsl_cpu[:-1]
        num_tokens_total = qsl_cpu[-1].item()
        token_indices = self.arange[:num_tokens_total]

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=qsl_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=num_tokens_total,
            max_query_len=query_lens.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        # Step back from each request's final scheduled token past its
        # rejected tail to land on the last accepted token.
        last_token_per_req = common_attn_metadata.query_start_loc[1:] - 1
        token_indices_to_sample = last_token_per_req - rejected_per_req

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
        """Propose draft tokens level by level for tree-structured speculation.

        The first level of drafts is sampled from ``logits``. Each following
        level is produced by running the draft model over all drafts
        accumulated so far (with tree attention metadata built for that
        level) and sampling the children of the drafts added at the previous
        level.

        Args:
            batch_size: Number of requests in the batch.
            logits: Logits from which the first tree level is sampled.
            positions: Base positions; drafts at loop level ``l`` use
                ``positions + l + 1`` for RoPE.
            hidden_states: Hidden states for the first level; viewed as
                ``(batch_size, 1, hidden_size)``.
            common_attn_metadata: Attention metadata of the initial step; a
                per-level copy is derived from it with ``dataclasses.replace``.

        Returns:
            One tensor per tree level, each viewed as ``(batch_size, -1)``,
            holding the draft token ids sampled at that level.
        """

        def sample_children(level_logits: torch.Tensor, k: int) -> torch.Tensor:
            # Argmax when every node has a single child, otherwise the k most
            # likely tokens per parent node.
            if k == 1:
                return level_logits.argmax(dim=-1).view(batch_size, -1)
            return torch.topk(level_logits, k, dim=-1).indices.view(batch_size, -1)

        # The tree attention builder is taken from the first attention group.
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        draft_token_ids = sample_children(logits, num_children)
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention. Each request's
            # query spans every draft accumulated so far.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
                draft_index=level + 1,
            )

            # Apply new attention metadata to all layers (same object shared,
            # so the in-place adjustments below are seen by every layer).
            per_layer_attn_metadata = {
                layer_name: attn_metadata for layer_name in self.attn_layer_names
            }

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // self.block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = (
                block_ids * self.block_size + query_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            self.input_ids[:num_tokens] = tree_input_ids.view(-1)
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
            else:
                num_input_tokens = num_tokens
                cudagraph_runtime_mode = CUDAGraphMode.NONE
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens added at this
            # level (the last `level_num_drafts` positions of each query).
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            draft_token_ids = sample_children(logits, num_children)
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

815
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """Prepare the inputs for speculative decoding.

        Updates ``common_attn_metadata`` to account for the rejected draft
        tokens (and newly sampled tokens), and computes the indices of the
        tokens that should be fed to the speculator.

        Args:
            common_attn_metadata: Attention metadata of the target-model step.
            sampled_token_ids: Per-request token ids sampled in this step.
            num_draft_tokens: Per-request number of draft tokens. A request
                with ``n > 0`` drafts has ``n + 1 - len(sampled)`` rejected
                tokens; requests with no drafts have none.

        Returns:
            ``(spec_common_attn_metadata, token_indices)``: metadata covering
            only the kept tokens, and for each kept token its index in the
            original flattened batch (on ``query_start_loc``'s device).
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        num_rejected_tokens = torch.tensor(
            [
                n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ],
            dtype=torch.int32,
        )

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        # (query lengths BEFORE removing the rejected tokens)
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = new_query_start_loc_np[-1]
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE: computed from the pre-rejection query lengths, so this is
            # an upper bound on the new per-request token counts.
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
916

917
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of ``model``.

        If ``model`` exposes a ``module`` attribute (e.g. a multi-GPU
        wrapper), one level is unwrapped and the inner model's class name is
        returned instead.
        """
        # Unwrap a single wrapper level only.
        model = getattr(model, "module", model)
        return model.__class__.__name__

922
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft (EAGLE head) model and wire it to ``target_model``.

        Besides loading the draft weights, this:
          * records the attention and ``DeepseekV32IndexerCache`` layer names
            added by the draft model (layers present after loading it minus
            those registered before),
          * builds a metadata builder for the draft indexer layers, if any,
          * disables multimodal inputs when the draft model cannot embed them,
          * copies the image token index from the target config for
            multimodal targets,
          * shares the vocab embedding and/or lm_head with the target model
            when possible.

        Args:
            target_model: The already-loaded target model.

        Raises:
            AttributeError: With a single pipeline-parallel rank, if the
                target language model has neither ``embed_tokens`` nor
                ``embedding``.
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        # Snapshot the target model's layers so the draft model's layers can
        # be identified by set difference after loading it.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        # Indexer layers are kept out of the regular attention layer list
        # (presumably they also register as AttentionLayerBase — verify).
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if self.indexer_layer_names:
            first_indexer_layer = indexer_layers[self.indexer_layer_names[0]]
            builder_cls = first_indexer_layer.get_attn_backend().get_builder_cls()
            self.draft_indexer_metadata_builder = builder_cls(
                first_indexer_layer.get_kv_cache_spec(self.vllm_config),
                self.indexer_layer_names,
                self.vllm_config,
                self.device,
            )
        else:
            self.draft_indexer_metadata_builder = None

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            if (
                self.get_model_name(target_model)
                == "Qwen2_5_VLForConditionalGeneration"
            ):
                self.model.config.image_token_index = target_model.config.image_token_id
            else:
                self.model.config.image_token_index = (
                    target_model.config.image_token_index
                )
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        embed_tokens_shared = False
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, "embed_tokens"):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, "embedding"):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            # Check if shapes match and we found the embedding
            eagle_shape = self.model.model.embed_tokens.weight.shape
            target_shape = target_embed_tokens.weight.shape
            if eagle_shape == target_shape:
                logger.info(
                    "Assuming the EAGLE head shares the same vocab embedding"
                    " with the target model."
                )
                del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
                embed_tokens_shared = True
        if not embed_tokens_shared:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                " from the target model."
            )

        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.vllm_config.speculative_config.method != "eagle3":
            if hasattr(target_language_model, "lm_head"):
                logger.info("Loading EAGLE LM head weights from the target model.")
                self.model.lm_head = target_language_model.lm_head
        elif (
            hasattr(self.model, "lm_head")
            and hasattr(target_language_model, "lm_head")
            and self.model.lm_head.weight.shape
            == target_language_model.lm_head.weight.shape
        ):
            logger.info(
                "Assuming the EAGLE head shares the same lm_head"
                " with the target model."
            )
            del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head
        else:
            logger.info(
                "The EAGLE head's lm_head will be loaded separately"
                " from the target model."
            )
1051

1052
1053
1054
1055
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> None:
        """Run the draft model once on its preallocated input buffers.

        No attention metadata is supplied (``None`` is passed to
        ``set_forward_context``).

        Args:
            num_tokens: Number of tokens to run. Padded with
                ``pad_for_cudagraph`` when CUDA graphs are enabled and
                ``num_tokens`` does not exceed the largest cudagraph size.
            use_cudagraphs: Allow CUDA graphs for this run; they are used only
                if ``self.use_cuda_graph`` is also set.
        """
        # Determine if CUDA graphs should be used for this run.
        cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
        if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)

        # NOTE(review): unlike propose_tree, the mode stays PIECEWISE even when
        # num_tokens exceeds the largest cudagraph batch size — confirm intended.
        cudagraph_runtime_mode = (
            CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
        )

        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            # Multimodal-capable draft models are fed embeddings, not ids.
            if self.supports_mm_inputs:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None

            self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_tokens),
                hidden_states=self.hidden_states[:num_tokens],
                inputs_embeds=inputs_embeds,
            )
1084

1085
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        The KV cache groups in ``self.runner.attn_groups`` are scanned in
        order. Within each group, the first attention group containing the
        first EAGLE layer is consulted; its builder is returned if it is not
        ``None``.

        Returns:
            The metadata builder for EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
        """
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    if builder is not None:
                        return builder
                    # Only the first matching group per KV cache group counts.
                    break

        # Explicit raise (not `assert`) so this still fails under `python -O`
        # instead of silently returning None.
        raise AssertionError(
            "Failed to find attention metadata builder for EAGLE layers."
        )

1110
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Raises:
            AssertionError: If the eagle layers span more than one group.
            KeyError: If an eagle layer is not in any KV cache group.
        """
        # Map each layer name to the index of its KV cache group.
        layer_to_group_id: dict[str, int] = {
            layer_name: group_id
            for group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        eagle_group_ids = {
            layer_to_group_id[layer_name] for layer_name in self.attn_layer_names
        }
        # Explicit raise so the check survives `python -O`.
        if len(eagle_group_ids) != 1:
            raise AssertionError(
                "All eagle layers should belong to the same kv cache group"
            )
1132

1133

1134
1135
1136
1137
# NOTE(woosuk): Currently, the code below is not used and we always use argmax
# to sample the draft tokens. We will use it once we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one token per row of ``logits`` using only the temperature.

    Args:
        logits: Logits, one row per request (the per-request temperature is
            broadcast via ``view(-1, 1)``). Unless every request is greedy,
            this tensor is divided by the temperature IN PLACE.
        sampling_metadata: Supplies ``all_greedy``, ``all_random`` and the
            per-request ``temperature`` tensor.

    Returns:
        ``(next_token_ids, probs)``. When all requests are greedy, ``probs``
        is the ``logits`` tensor itself (it is not used by rejection sampling
        in that case). Otherwise it is the float32 softmax of the
        temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        return logits.argmax(dim=-1), logits

    assert sampling_metadata.temperature is not None

    temperature = sampling_metadata.temperature
    if not sampling_metadata.all_random:
        # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
        # consistent with sampler.py's _SAMPLING_EPS threshold.
        is_greedy = temperature < _SAMPLING_EPS
        # Avoid division by zero for the greedy requests.
        temperature = torch.where(is_greedy, 1.0, temperature)

    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    # Exponential-race sampling: argmax(probs / q) with q ~ Exp(1) picks
    # index i with probability probs[i].
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        # Greedy requests take the argmax instead of the random sample.
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs