eagle.py 44.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6
from typing import Optional
7

8
import numpy as np
9
10
11
import torch
import torch.nn as nn

12
13
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
14
                         get_layers_from_vllm_config)
15
from vllm.distributed.parallel_state import get_pp_group
16
from vllm.forward_context import set_forward_context
17
from vllm.logger import init_logger
18
from vllm.model_executor.model_loader import get_model
19
from vllm.model_executor.models import supports_multimodal
20
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
21
from vllm.platforms import current_platform
22
from vllm.utils import is_pin_memory_available
23
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
24
25
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                  TreeAttentionMetadataBuilder)
26
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
27
28
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
29
from vllm.v1.kv_cache_interface import KVCacheConfig
30
from vllm.v1.sample.metadata import SamplingMetadata
31
32
33
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
34

35
36
logger = init_logger(__name__)

# Sentinel written into the slot mapping for draft tokens whose KV entries
# must not be stored (e.g. positions beyond max_model_len), so the KV cache
# is not updated with padding tokens.
PADDING_SLOT_ID = -1

39
40
41
42
43
44
45

class EagleProposer:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Read drafting configuration and allocate persistent buffers.

        Args:
            vllm_config: Engine configuration. The speculative, model, cache,
                scheduler and compilation sub-configs are read here.
            device: Device on which the persistent buffers are allocated.
            runner: Owning model runner. It may be omitted at construction
                time, but ``propose()`` asserts it is set and
                ``propose_tree()`` reads its ``attn_groups``.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        max_batch_size = vllm_config.scheduler_config.max_num_seqs

        # query_start_loc has one more element than batch_size, and both the
        # device arange and its host (numpy) twin are sliced with
        # [:batch_size + 1] when building it. Size them identically so the
        # host slice is never one element short when max_num_seqs is not
        # smaller than max_num_batched_tokens.
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.token_arange_np = np.arange(max_num_slots_for_arange)

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        self.is_multimodal_model = (
            vllm_config.model_config.is_multimodal_model)

        # Cached attention metadata builder for the draft layers. While this
        # is None, propose() asks _get_attention_metadata_builder() instead.
        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE and
                               not self.vllm_config.model_config.enforce_eager)
        # propose() uses the last entry of this list as the largest size that
        # may be padded for CUDA graphs.
        # NOTE(review): this assumes cudagraph_capture_sizes is configured in
        # descending order, so that reversing it puts the maximum last.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Persistent buffers for CUDA graphs: inputs are copied into prefixes
        # of these tensors so that captured graphs see stable addresses.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=device)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)

        # Device-side arange used for query_start_loc and token indices.
        self.arange = torch.arange(max_num_slots_for_arange,
                                   device=device,
                                   dtype=torch.int32)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)

        # Host/device staging buffer for the fallback next-token ids used by
        # prepare_next_token_ids_padded() when a request has no valid
        # sampled token.
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True)

        # Determine allowed attention backends once during initialization.
        # None means "no restriction" (only ROCm restricts the metadata types
        # accepted for multi-token drafting).
        self.allowed_attn_types: Optional[tuple] = None
        if current_platform.is_rocm():
            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
            # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
            if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata)
                rocm_types.append(AiterFlashAttentionMetadata)
            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree. Each node is a tuple whose length
        # is the node's depth in the tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int,
                                      ...]] = ast.literal_eval(spec_token_tree)
        # NOTE(review): the depth is taken from the last node, which assumes
        # nodes are listed in order of non-decreasing depth.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        # cu_drafts_per_level[l]: number of nodes at levels 0..l.
        # child_drafts_per_level[l]: children per parent at level l. Integer
        # division assumes a uniform branching factor within a level.
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
                                            num_drafts_per_level[level])
            self.child_drafts_per_level.append(num_drafts_per_level[level] //
                                               num_drafts_per_level[level - 1])
        # Precompute draft position offsets in flattened tree:
        # [max_batch_size, num_tree_nodes] holding 1..num_tree_nodes per row.
        self.tree_draft_pos_offsets = torch.arange(
            1,
            len(self.tree_choices) + 1,
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)

146
147
148
149
150
151
152
153
154
155
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Greedily propose draft tokens for every request in the batch.

        The first draft token of each request is produced by running the
        draft model on the target tokens shifted left by one, with
        ``next_token_ids`` placed at each request's last position. Further
        draft tokens are produced one per request per step, unless the
        attention backend is tree attention, in which case drafting is
        delegated to ``propose_tree()``.

        Args:
            target_token_ids: Flattened token ids of the target step.
            target_positions: Positions matching ``target_token_ids``.
            target_hidden_states: Target-model hidden states per token.
            next_token_ids: One token id per request to append.
            last_token_indices: Index of each request's last token in the
                flattened token dimension. Defaults to
                ``query_start_loc[1:] - 1``.
            common_attn_metadata: Attention metadata for the target tokens.
                NOTE: mutated in place when more than one draft token is
                generated on the non-tree path.
            sampling_metadata: Not read here; drafts are chosen by argmax.
            mm_embeds: Optional multimodal embeddings for the first pass.

        Returns:
            Draft token ids of shape ``[batch_size, num_speculative_tokens]``
            or, for tree attention, the per-level drafts returned by
            ``propose_tree()`` concatenated along dim 1.

        Raises:
            ValueError: If more than one draft token is requested and the
                attention metadata type is not in ``allowed_attn_types``.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        # Select the attention metadata builder for the EAGLE layers, reusing
        # self.attn_metadata_builder when it has been set.
        builder = (self._get_attention_metadata_builder()
                   if self.attn_metadata_builder is None else
                   self.attn_metadata_builder)
        attn_metadata = builder.build_for_drafting(  # type: ignore
            common_attn_metadata=common_attn_metadata,
            draft_index=0)

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        if self.use_cuda_graph and \
                num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states
        if self.is_multimodal_model:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = self.model.get_input_embeddings(
                input_ids,
                multimodal_embeddings=mm_embeds or None,
            )
            self.inputs_embeds[:num_tokens] = inputs_embeds
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            inputs_embeds = None
            input_ids = self.input_ids[:num_input_tokens]

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self.positions[:num_input_tokens],
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            # "mtp" models return a single tensor; the others return a
            # (last_hidden_states, hidden_states) pair.
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        positions = target_positions[last_token_indices]
        if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and \
            not isinstance(attn_metadata, self.allowed_attn_types):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}")

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        if self.use_cuda_graph and \
                batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size

        # From here on every request contributes exactly one query token.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[:batch_size + 1]).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            # Increment the sequence lengths. max_seq_len must advance with
            # them, otherwise the metadata builder is handed a stale upper
            # bound. It is capped at max_model_len (as propose_tree() does)
            # because over-length requests are shrunk to length 1 below.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len)
            common_attn_metadata.seq_lens += 1
            common_attn_metadata.seq_lens_cpu += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            # NOTE(review): only the device copy is masked; seq_lens_cpu is
            # left incremented, presumably to avoid a device-to-host copy of
            # exceeds_max_model_len.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
                                                       1)

            common_attn_metadata.num_computed_tokens_cpu = \
                common_attn_metadata.seq_lens_cpu - 1

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            common_attn_metadata.slot_mapping = (
                block_ids * self.block_size +
                clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID)

            # Rebuild attention metadata
            attn_metadata = builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata,
                draft_index=token_index + 1)
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states
            if self.is_multimodal_model:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
                self.inputs_embeds[:batch_size] = inputs_embeds
                inputs_embeds = self.inputs_embeds[:input_batch_size]
                input_ids = None
            else:
                inputs_embeds = None
                input_ids = self.input_ids[:input_batch_size]

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self.positions[:input_batch_size],
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            # Drop CUDA-graph padding rows before sampling.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
367

368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
    def prepare_next_token_ids_cpu(
            self, sampled_token_ids: list[list[int]],
            requests: dict[str,
                           CachedRequestState], gpu_input_batch: InputBatch,
            num_scheduled_tokens: dict[str, int]) -> torch.Tensor:
        """
        Pick the token each request should feed to the drafter next.

        Requests that sampled at least one token use their last sampled
        token. Requests that sampled nothing (partial prefill) fall back to
        the token stored in the request state at position
        ``num_computed_tokens + num_scheduled_tokens``. The result is an
        int32 tensor on the same device as the drafter's input buffer.
        """
        batch_req_ids = gpu_input_batch.req_ids

        def _pick(row: int, sampled: list[int]) -> int:
            # Common case: the request produced tokens this step.
            if sampled:
                return sampled[-1]
            # Partial prefill (rare case): read from the request state.
            rid = batch_req_ids[row]
            state = requests[rid]
            return state.get_token_id(state.num_computed_tokens +
                                      num_scheduled_tokens[rid])

        chosen = [_pick(row, sampled)
                  for row, sampled in enumerate(sampled_token_ids)]
        return torch.tensor(chosen,
                            dtype=torch.int32,
                            device=self.input_ids.device)

    def prepare_next_token_ids_padded(self,
                               common_attn_metadata: CommonAttentionMetadata,
                               sampled_token_ids: torch.Tensor,
                               requests: dict[str, CachedRequestState],
                               gpu_input_batch: InputBatch,
                               discard_request_indices: torch.Tensor,
                               num_discarded_requests: int) -> \
                                tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.

        Args:
            common_attn_metadata: Its ``seq_lens_cpu`` selects the fallback
                token of each request.
            sampled_token_ids: Sampled ids, one row per request. Entries equal
                to -1 or >= ``gpu_input_batch.vocab_size`` are treated as
                invalid.
            requests: Request states keyed by request id.
            gpu_input_batch: Supplies ``num_reqs``, ``req_ids`` and
                ``vocab_size``.
            discard_request_indices: Row indices of discarded requests; only
                the first ``num_discarded_requests`` entries are used.
            num_discarded_requests: Number of valid entries in
                ``discard_request_indices``.

        Returns:
            ``(next_token_ids, valid_sampled_tokens_count)``, one entry per
            row of ``sampled_token_ids``. Discarded rows report 0 valid tokens
            and take their next token from the fallback buffer.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        self.backup_next_token_ids.np[:num_reqs] = np.array([
            requests[gpu_input_batch.req_ids[i]].get_token_id(
                common_attn_metadata.seq_lens_cpu[i].item())
            for i in range(num_reqs)
        ])
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled tokens indices that should not be sampled.
        discard_sampled_tokens_req_indices = \
            discard_request_indices[:num_discarded_requests]

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1)

        # Generate a mask for all valid tokens within those requests. The mask
        # is always derived from the data, even with a single sampled token
        # per request: discarded rows were just filled with -1 and must count
        # as having no valid token, so that they fall back to the backup id
        # instead of returning -1.
        valid_mask = (
            (valid_sampled_token_ids_gpu != -1) &
            (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row (-1 when there is none).
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1,
            last_valid_indices_safe.unsqueeze(1)).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1, selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size])

        return next_token_ids, valid_sampled_tokens_count

    def prepare_inputs_padded(self,
                              common_attn_metadata: CommonAttentionMetadata,
                              spec_decode_metadata: SpecDecodeMetadata,
                              valid_sampled_tokens_count: torch.Tensor) -> \
            tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        Build drafter inputs that keep rejected tokens as padding.

        A new CommonAttentionMetadata is returned that covers every scheduled
        token of the target step, including the rejected drafts; the
        caller's metadata object is not modified. Rejected positions are
        skipped later through the returned sampling indices. No blocking CPU
        operations should be introduced in this function.

        Returns:
            A tuple ``(metadata, token_indices, token_indices_to_sample)``.
            ``token_indices`` enumerates all ``query_start_loc_cpu[-1]``
            tokens. ``token_indices_to_sample`` is each request's last token
            index minus its number of rejected drafts.
        """
        cu_drafts = spec_decode_metadata.cu_num_draft_tokens
        # Undo the cumulative sum to get the number of drafts per request.
        drafts_per_req = torch.cat([
            cu_drafts[0:1],
            cu_drafts[1:] - cu_drafts[:-1]
        ])

        # Requests with drafts reject (drafts + 1 - valid sampled) tokens;
        # requests without drafts reject nothing.
        rejected_per_req = torch.where(
            drafts_per_req > 0,
            drafts_per_req + 1 - valid_sampled_tokens_count,
            torch.zeros_like(drafts_per_req))

        qsl_cpu = common_attn_metadata.query_start_loc_cpu
        query_lens_cpu = qsl_cpu[1:] - qsl_cpu[:-1]
        total_num_tokens = qsl_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        # Index of the last non-rejected token of every request.
        token_indices_to_sample = (
            common_attn_metadata.query_start_loc[1:] - 1 - rejected_per_req)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=qsl_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=query_lens_cpu.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
        )

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

520
521
522
523
524
525
526
527
528
529
530
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
531
        tree_attn_metadata_builder = \
532
            self.runner.attn_groups[0][0].get_metadata_builder()
533
534
535
        assert isinstance(tree_attn_metadata_builder,
                          TreeAttentionMetadataBuilder)

536
        total_num_drafts = self.cu_drafts_per_level[0]
537
538
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
539
        num_children = self.child_drafts_per_level[0]
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children,
                                         dim=-1).indices.view(batch_size, -1)
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(0,
                                     device=self.input_ids.device,
                                     dtype=self.input_ids.dtype)
        tree_positions = torch.empty(0,
                                     device=self.positions.device,
                                     dtype=self.positions.dtype)
        tree_hidden_states = torch.empty(0,
                                         device=self.hidden_states.device,
                                         dtype=self.hidden_states.dtype)
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) +
            self.tree_draft_pos_offsets[:batch_size, :])
        tree_depth = len(self.cu_drafts_per_level)
563
        for level in range(tree_depth - 1):
564
565
566
567
568
569
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions +
                                     total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
570
            draft_positions = torch.where(
571
572
573
                exceeds_max_model_len,
                0,
                draft_positions,
574
575
            ).view(batch_size, -1)

576
577
            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
578
579
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1)
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1)

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids],
                                       dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions],
                                       dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1)

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
596
            query_len = total_num_drafts
597
598
599
600
601
602
603
604
605
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[:batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
606
                draft_index=level + 1,
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
            )

            # Apply new attention metadata to all layers.
            per_layer_attn_metadata = {}
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # Consider max model length.
            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                            self.max_model_len)
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            query_positions = flattened_draft_positions[:, level:level +
                                                        query_len]
            block_numbers = query_positions // self.block_size
            block_ids = attn_metadata.block_table.gather(dim=1,
                                                         index=block_numbers)
            slot_mapping = (block_ids * self.block_size +
                            query_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(
                num_tokens, -1)

            if self.use_cuda_graph and \
644
                    num_tokens <= self.cudagraph_batch_sizes[-1]:
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
                num_input_tokens = self.vllm_config.pad_for_cudagraph(
                    num_tokens)
            else:
                num_input_tokens = num_tokens
            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=num_input_tokens):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1)[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1)[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts,
669
                                                 -1))
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children,
                                             dim=-1).indices.view(
                                                 batch_size, -1)
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level +
                                                        1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

687
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """Prepare the inputs for the speculator after rejection sampling.

        Builds a new ``CommonAttentionMetadata`` (the input metadata is not
        modified) that drops the rejected draft tokens of each request, and
        returns the indices, into the original flattened batch, of the tokens
        that should be fed to the speculator.

        Args:
            common_attn_metadata: Metadata of the target-model verification
                batch.
            sampled_token_ids: Per request, the token ids sampled during
                verification. For a request with ``n > 0`` draft tokens, the
                number of rejected tokens is ``n + 1 - len(sampled)``.
            num_draft_tokens: Per request, the number of draft tokens that
                were proposed (requests with 0 drafts reject nothing).

        Returns:
            ``(spec_common_attn_metadata, token_indices)`` where
            ``token_indices`` is a device tensor of the kept token positions.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
        num_rejected_tokens = torch.tensor(
            [
                n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ],
            dtype=torch.int32,
        )

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = (common_attn_metadata.seq_lens_cpu -
                            num_rejected_tokens)

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        # (these are the ORIGINAL, pre-rejection query lengths)
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available())
        # The numpy view shares memory with the tensor, so the cumsum below
        # fills `new_query_start_loc_cpu` in place.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Python int (not a numpy scalar): used for slicing and stored as
        # `num_actual_tokens` in the metadata.
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                                  new_num_tokens_per_req_np)
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (self.token_arange_np[:total_num_tokens] -
                         new_query_start_locs_expanded)

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(
            device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device,
                                                       non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # The speculator's queries have the post-rejection lengths, so
            # the max must be taken over those, not the original lengths.
            max_query_len=new_num_tokens_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
        )

        return spec_common_attn_metadata, token_indices

    def load_model(self, target_model: nn.Module) -> None:
        """Load the EAGLE draft model and share weights with the target model.

        Steps:
          1. Build the draft model with ``get_model`` under the "eagle_head"
             model tag, and record as ``self.attn_layer_names`` the attention
             layers that appear only after loading it.
          2. In multimodal mode, probe the draft model's
             ``get_input_embeddings``; fall back to text-only mode if the
             probe raises.
          3. With a single pipeline-parallel rank, reuse the target's input
             embedding when its weight shape matches the draft's.
          4. Reuse the target's ``lm_head``: whenever present for non-"eagle3"
             methods, and only on a weight-shape match for "eagle3".

        Args:
            target_model: The target model whose embedding / lm_head may be
                shared with the draft model.

        Raises:
            AttributeError: If, with a single PP rank, the target language
                model has neither an ``embed_tokens`` nor an ``embedding``
                attribute.
        """
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        # Snapshot the target's attention layers so the draft's layers can be
        # identified as those registered by loading the draft model.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        from vllm.compilation.backends import set_model_tag
        with set_model_tag("eagle_head"):
            self.model = get_model(vllm_config=self.vllm_config,
                                   model_config=draft_model_config)

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)
        # Sort for a deterministic order: set iteration order depends on
        # string hash randomization, and `attn_layer_names[0]` is used to
        # look up the attention metadata builder.
        self.attn_layer_names = sorted(draft_attn_layer_names)

        if self.is_multimodal_model:
            # The target may be multimodal while the draft model is text-only:
            # probe the draft's embedding API and downgrade if unsupported.
            try:
                dummy_input_ids = torch.tensor([[1]],
                                               device=self.input_ids.device)
                self.model.get_input_embeddings(dummy_input_ids,
                                                multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode")
                self.is_multimodal_model = False

        if supports_multimodal(target_model):
            # handle multimodality
            self.model.config.image_token_index = (
                target_model.config.image_token_index)
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # Share embed_tokens with the target model when possible. This is only
        # attempted with a single pipeline-parallel rank.
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, 'embed_tokens'):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, 'embedding'):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' "
                    "attribute")

            # Share only when the embedding weight shapes match.
            eagle_shape = self.model.model.embed_tokens.weight.shape
            target_shape = target_embed_tokens.weight.shape
            if eagle_shape == target_shape:
                logger.info(
                    "Assuming the EAGLE head shares the same vocab embedding"
                    " with the target model.")
                del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
            else:
                logger.info(
                    "The EAGLE head's vocab embedding will be loaded separately"
                    " from the target model.")
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                " from the target model.")

        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.vllm_config.speculative_config.method != "eagle3":
            if hasattr(target_language_model, "lm_head"):
                logger.info(
                    "Loading EAGLE LM head weights from the target model.")
                self.model.lm_head = target_language_model.lm_head
        else:
            # EAGLE3 heads may use a different vocabulary, so share lm_head
            # only when both exist and their weight shapes match.
            if (hasattr(self.model, "lm_head")
                    and hasattr(target_language_model, "lm_head")
                    and self.model.lm_head.weight.shape
                    == target_language_model.lm_head.weight.shape):
                logger.info("Assuming the EAGLE head shares the same lm_head"
                            " with the target model.")
                del self.model.lm_head
                self.model.lm_head = target_language_model.lm_head
            else:
                logger.info(
                    "The EAGLE head's lm_head will be loaded separately"
                    " from the target model.")

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
    ) -> None:
        """Run one forward pass of the draft model over ``num_tokens`` tokens.

        The forward context is entered without attention metadata. The model
        is fed the first ``num_tokens`` entries of the persistent buffers:
        ``inputs_embeds`` when ``self.is_multimodal_model`` is set, otherwise
        ``input_ids``. Exactly one of the two is passed; the other is None.
        """
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            feed_embeds = self.is_multimodal_model
            self.model(
                input_ids=None if feed_embeds else self.input_ids[:num_tokens],
                positions=self.positions[:num_tokens],
                hidden_states=self.hidden_states[:num_tokens],
                inputs_embeds=(self.inputs_embeds[:num_tokens]
                               if feed_embeds else None),
            )

    def _get_attention_metadata_builder(
            self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        The builder is taken from the first attention group in
        ``self.runner.attn_groups`` that contains the first EAGLE layer
        (``self.attn_layer_names[0]``). All EAGLE layers are expected to share
        one KV cache group; see ``validate_same_kv_cache_group``.

        Returns:
            The metadata builder for EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
        """
        builder = None
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    break
            if builder is not None:
                break

        # Explicit raise rather than `assert`, so the check is not stripped
        # under `python -O` (which would silently return None).
        if builder is None:
            raise AssertionError(
                "Failed to find attention metadata builder for EAGLE layers.")
        return builder

    def validate_same_kv_cache_group(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """Validate that all EAGLE layers belong to the same KV cache group.

        This assumption lets all EAGLE layers share a single
        AttentionMetadata. It may be relaxed to support multiple
        AttentionMetadata in the future.

        Raises:
            KeyError: If an EAGLE layer is not listed in any KV cache group.
            AssertionError: If the EAGLE layers span more than one group (or
                there are no EAGLE layers).
        """
        # Map every layer to the index of its KV cache group; as with the
        # original loop, a later group wins if a layer is listed twice.
        layer_to_group_id: dict[str, int] = {
            layer_name: group_id
            for group_id, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        eagle_group_ids = {
            layer_to_group_id[layer_name]
            for layer_name in self.attn_layer_names
        }
        # Explicit raise so the check survives `python -O`.
        if len(eagle_group_ids) != 1:
            raise AssertionError(
                "All eagle layers should belong to the same kv cache group")


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row and return it with the draft probs.

    Args:
        logits: Draft-model logits, presumably 2-D with one row per request
            (the temperature is broadcast as a column). Unless every request
            is greedy, this tensor is divided by the temperature IN PLACE.
        sampling_metadata: Provides ``all_greedy``, ``all_random`` and the
            per-request ``temperature`` (greedy requests use ``-1``).

    Returns:
        ``(next_token_ids, probs)``. If every request is greedy, ``probs`` is
        the ``logits`` tensor itself; otherwise it is the float32 softmax of
        the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # Rejection sampling does not use draft_probs for greedy requests,
        # so the logits are returned in their place.
        return logits.argmax(dim=-1), logits

    # Greedy rows (temperature == -1) get a neutral temperature of 1.0 so the
    # division below leaves them unchanged.
    greedy_mask = sampling_metadata.temperature == -1
    safe_temperature = torch.where(greedy_mask, 1.0,
                                   sampling_metadata.temperature)
    logits.div_(safe_temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    # argmax(p / E) with i.i.d. E ~ Exp(1) draws an index distributed as p.
    # The division is out of place on purpose: `probs` is returned and is
    # used later for rejection sampling, so it must not be overwritten.
    exp_noise = torch.empty_like(probs).exponential_()
    next_token_ids = (probs / exp_noise).argmax(dim=-1).view(-1)

    if sampling_metadata.all_random:
        return next_token_ids, probs

    # Mixed batch: greedy rows take the argmax of their distribution instead.
    next_token_ids = torch.where(greedy_mask, probs.argmax(dim=-1),
                                 next_token_ids)
    return next_token_ids, probs