eagle.py 47.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6
from typing import Optional
7

8
import numpy as np
9
10
11
import torch
import torch.nn as nn

12
from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config
13
from vllm.distributed.parallel_state import get_pp_group
14
from vllm.forward_context import set_forward_context
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
17
from vllm.model_executor.model_loader import get_model
18
from vllm.model_executor.models import supports_multimodal
19
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
20
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
21
from vllm.multimodal import MULTIMODAL_REGISTRY
22
from vllm.platforms import current_platform
23
from vllm.utils import is_pin_memory_available
24
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
25
26
27
28
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
29
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
30
31
32
33
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
34
from vllm.v1.kv_cache_interface import KVCacheConfig
35
from vllm.v1.sample.metadata import SamplingMetadata
36
37
38
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
39
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
40

41
42
logger = init_logger(__name__)

# Slot id written into ``slot_mapping`` for draft tokens whose position
# exceeds ``max_model_len``; it keeps those padding tokens from updating the
# KV cache.
PADDING_SLOT_ID = -1


class EagleProposer:
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Read the speculative config and allocate persistent buffers.

        Args:
            vllm_config: Global vLLM config. Its ``speculative_config`` must
                not be ``None``.
            device: Device on which the persistent tensors are allocated.
            runner: Owning model runner. ``propose`` asserts it is set and
                reads ``runner.attn_groups`` to obtain attention metadata
                builders.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
        self.draft_indexer_metadata_builder: Optional[AttentionMetadataBuilder] = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.use_cuda_graph = (
            not current_platform.is_xpu()
            and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not self.vllm_config.model_config.enforce_eager
            and not self.speculative_config.enforce_eager
        )
        # NOTE: propose() treats the last entry as the largest capturable
        # size; this presumes cudagraph_capture_sizes is stored in descending
        # order.
        self.cudagraph_batch_sizes = (
            list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
            if self.use_cuda_graph
            else []
        )

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        if self.uses_mrope:
            # M-RoPE needs positions of shape (3, max_num_tokens).
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens), dtype=torch.int64, device=device
            )
        else:
            # RoPE needs positions of shape (max_num_tokens,).
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the aranges are used to set query_start_loc
        # (and query_start_loc_cpu), which has one more element than batch_size.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )
        # Host-side twin of ``self.arange``; it must have the same length.
        # Sizing it by max_num_tokens alone made ``[: batch_size + 1]`` one
        # element short when batch_size == max_num_tokens.
        self.token_arange_np = np.arange(max_num_slots_for_arange)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # Fallback next-token ids (one slot per request), staged on the host
        # and copied to the device by prepare_next_token_ids_padded().
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Determine allowed attention backends once during initialization.
        self.allowed_attn_types: Optional[tuple] = None
        if current_platform.is_rocm():
            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
            # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
            if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)
            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # NOTE: assumes tree_choices is ordered by depth (last node deepest).
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        # cu_drafts_per_level: cumulative node count up to each level.
        # child_drafts_per_level: children per parent at each level (integer
        # division, i.e. presumes uniform branching within a level).
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1,
            len(self.tree_choices) + 1,
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)

    def _get_positions(self, num_tokens: int):
        """Return a view over the first ``num_tokens`` persistent positions.

        The M-RoPE buffer is 2-D, so it is sliced along its token axis; the
        plain RoPE buffer is 1-D.
        """
        if not self.uses_mrope:
            return self.positions[:num_tokens]
        return self.mrope_positions[:, :num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        """Write ``positions`` into the first ``num_tokens`` persistent slots.

        The M-RoPE buffer is filled along its token axis; the plain RoPE
        buffer is filled directly.
        """
        if not self.uses_mrope:
            self.positions[:num_tokens] = positions
            return
        self.mrope_positions[:, :num_tokens] = positions

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run the draft model and return the proposed draft token ids.

        The first draft step consumes the target model's tokens (shifted by
        one), positions and hidden states. Each later step feeds back the
        previous step's argmax token for every request.

        Args:
            target_token_ids: Token ids processed by the target model.
            target_positions: Their positions (``[3, num_tokens]`` with M-RoPE).
            target_hidden_states: Target-model hidden states for those tokens.
            next_token_ids: One token per request, written at
                ``last_token_indices`` of the shifted input ids.
            last_token_indices: Index of each request's last token in the flat
                token batch; derived from ``query_start_loc`` when ``None``.
            common_attn_metadata: Attention metadata of the target pass.
                NOTE: mutated in place by the multi-step (non-tree) loop.
            sampling_metadata: Not used by this method; drafting is greedy.
            mm_embed_inputs: Optional ``(mm_embeds, is_mm_embed)`` forwarded to
                ``get_input_embeddings`` for multimodal draft models.

        Returns:
            ``[batch_size, num_speculative_tokens]`` draft token ids, or
            ``[batch_size, num_tree_tokens]`` when tree attention is used.

        Raises:
            ValueError: If more than one draft token is requested and the
                attention metadata type is not in ``self.allowed_attn_types``.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        # FIXME: need to consider multiple kv_cache_groups
        ubatch_id = dbo_current_ubatch_id()
        attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[
            ubatch_id
        ]
        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
            self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        with set_forward_context(
            per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]
        if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size

        # Every later step feeds exactly one token per request.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1
            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length.
            # Since it is complex to remove such requests from the batch,
            # we keep them in the batch but adjust the position ids
            # and slot mappings to avoid the
            # out-of-range access during the model execution.
            # The draft tokens generated with this adjustment
            # should be ignored.
            if self.uses_mrope:
                # All dimensions of M-RoPE positions are the same here, so
                # row 0 drives the length check and the slot mapping.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
                slot_positions = clamped_positions[0]
            else:
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
                slot_positions = clamped_positions

            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            common_attn_metadata.seq_lens_cpu += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
            # Keep max_seq_len an upper bound of seq_lens. Without this it
            # stays at the target-pass value while seq_lens grow every step,
            # so the rebuilt metadata would under-report the longest sequence.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )
            common_attn_metadata.num_computed_tokens_cpu = (
                common_attn_metadata.seq_lens_cpu - 1
            )

            # Compute the slot mapping.
            block_numbers = slot_positions // self.block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            common_attn_metadata.slot_mapping = (
                block_ids * self.block_size + slot_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.get_input_embeddings(
                    input_ids
                )
                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
            ):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(input_batch_size),
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.

        Returns:
            An int32 tensor with one next-token id per entry of
            ``sampled_token_ids``, placed on the same device as
            ``self.input_ids``.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        # Build the tensor directly instead of rebinding the list variable.
        return torch.tensor(
            next_token_ids, dtype=torch.int32, device=self.input_ids.device
        )

    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_indices: torch.Tensor,
        num_discarded_requests: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.

        Args:
            sampled_token_ids: 2-D tensor with one row per request; entries
                equal to -1 (or >= vocab_size) are treated as invalid.
            discard_request_indices: Row indices of discarded requests; only
                the first ``num_discarded_requests`` entries are used.

        Returns:
            ``(next_token_ids, valid_sampled_tokens_count)``, one entry per
            request row.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token.
        # seq_lens_cpu is converted to Python ints once rather than with one
        # ``.item()`` call per request.
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [requests[req_ids[i]].get_token_id(seq_lens[i]) for i in range(num_reqs)]
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled tokens indices that should not be sampled.
        discard_sampled_tokens_req_indices = discard_request_indices[
            :num_discarded_requests
        ]

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1
        )

        # Generate a mask for all valid tokens within those requests
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row.
        # NOTE: this presumes valid tokens form a contiguous prefix of each row.
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
        ).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size],
        )

        return next_token_ids, valid_sampled_tokens_count

    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.

        Returns:
            ``(spec_common_attn_metadata, token_indices,
            token_indices_to_sample)``. ``token_indices`` is
            ``arange(total_num_tokens)``. ``token_indices_to_sample`` holds,
            per request, the flat index of the last non-rejected token of its
            query.
        """
        # Per-request draft counts recovered from the cumulative counts:
        # [c0, c1 - c0, c2 - c1, ...].
        num_draft_tokens_gpu = torch.cat(
            [
                spec_decode_metadata.cu_num_draft_tokens[0:1],
                spec_decode_metadata.cu_num_draft_tokens[1:]
                - spec_decode_metadata.cu_num_draft_tokens[:-1],
            ]
        )

        # Requests without drafts reject nothing; otherwise the rejected count
        # is (num_draft + 1) minus the number of valid sampled tokens.
        num_rejected_tokens_gpu = torch.where(
            num_draft_tokens_gpu > 0,
            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
            torch.zeros_like(num_draft_tokens_gpu),
        )

        # Query layout is unchanged: rejected tokens stay in as padding.
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        token_indices_to_sample = (
            common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu
        )

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
620
621
622
623
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)
624

625
        total_num_drafts = self.cu_drafts_per_level[0]
626
627
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
628
        num_children = self.child_drafts_per_level[0]
629
630
631
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
632
633
634
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
635
636
637
638
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
639
640
641
642
643
644
645
646
647
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
648
649
        # Precompute the draft token positions.
        flattened_draft_positions = (
650
651
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
652
        tree_depth = len(self.cu_drafts_per_level)
653
        for level in range(tree_depth - 1):
654
655
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
656
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
657
658
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
659
            draft_positions = torch.where(
660
661
662
                exceeds_max_model_len,
                0,
                draft_positions,
663
664
            ).view(batch_size, -1)

665
666
            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
667
                draft_positions = draft_positions.repeat_interleave(
668
669
                    level_num_drafts, dim=1
                )
670
671
672
673

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
674
675
                    num_children, dim=1
                )
676
677

            # Concatenate the draft tokens, positions, and hidden states.
678
679
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
680
            tree_hidden_states = torch.cat(
681
682
                [tree_hidden_states, draft_hidden_states], dim=1
            )
683
684
685

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
686
            query_len = total_num_drafts
687
688
            common_attn_metadata = replace(
                common_attn_metadata,
689
                query_start_loc=query_len * self.arange[: batch_size + 1],
690
691
692
693
694
695
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
696
                draft_index=level + 1,
697
698
699
700
701
702
703
704
            )

            # Apply new attention metadata to all layers.
            per_layer_attn_metadata = {}
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # Consider max model length.
705
706
707
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
708
709
710
711
712
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
713
            query_positions = flattened_draft_positions[:, level : level + query_len]
714
            block_numbers = query_positions // self.block_size
715
716
717
718
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = (
                block_ids * self.block_size + query_positions % self.block_size
            )
719
720
721
722
723
724
725
726
727
728
729
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
730
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)
731

732
733
            if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
734
735
736
            else:
                num_input_tokens = num_tokens
            # Run the model.
737
738
739
            with set_forward_context(
                per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
            ):
740
741
742
743
744
745
746
747
748
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            draft_hidden_states = hidden_states[:num_tokens].view(
749
750
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
751
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
752
753
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
754
755
756

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
757
758
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )
759
760
761
762
763
764

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
765
766
767
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
768
769
770
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
771
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
772
773
774
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

775
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.

        Example with three requests, where request i has query length qi,
        sequence length si and ni rejected draft tokens:

          common_attn_metadata.query_start_loc{_cpu}:
              [0, q1, q1 + q2, q1 + q2 + q3]
          common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
          num_rejected_tokens: [n1, n2, n3]

        Intermediate value:
          num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]

        Returned metadata / indices:
          query_start_loc{_cpu}:
              [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
          seq_lens{_cpu}: [s1 - n1, s2 - n2, s3 - n3]
          max_query_len: max(q1 - n1, q2 - n2, q3 - n3)
          token_indices: [0, 1, ..., q1 - n1 - 1,
                          q1, q1 + 1, ..., q1 + q2 - n2 - 1,
                          q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
        """
        # A request that had n > 0 draft tokens rejects
        # n + 1 - len(sampled_token_ids[i]) of them; requests without drafts
        # reject nothing.
        num_rejected_tokens = torch.tensor(
            [
                n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ],
            dtype=torch.int32,
        )

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]  (pre-rejection)
        query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        # The numpy view shares memory with the tensor, so the cumsum below
        # fills new_query_start_loc_cpu in place.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Plain Python int: it is stored in the `num_actual_tokens` field.
        total_num_tokens = int(new_query_start_loc_np[-1])

        # Example assuming new_num_tokens_per_req_np = [2, 4, 3], i.e. new
        # query start locs [0, 2, 6, 9]:
        #   new starts, expanded per token: [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #   offset within request:          [0, 1, 0, 1, 2, 3, 0, 1, 2]
        # NOTE(review): assumes self.token_arange_np is an arange at least
        # total_num_tokens long — confirm against its allocation.
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Shift each offset by its request's ORIGINAL start so the indices
        # address the un-pruned token layout:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # Use the post-rejection lengths; the pre-rejection maximum only
            # over-estimates the longest query in the pruned batch.
            max_query_len=new_num_tokens_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
876

877
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of ``model``.

        When ``model`` exposes a ``module`` attribute (a wrapper, e.g. for
        multi-GPU execution), exactly one level is unwrapped and the class
        name of the wrapped object is returned instead.
        """
        inner = model.module if hasattr(model, "module") else model
        return inner.__class__.__name__

882
    def load_model(self, target_model: nn.Module) -> None:
        """Load the EAGLE draft model and connect it to ``target_model``.

        Besides loading the draft weights (under the ``eagle_head`` model
        tag), this:
          * records the attention and DeepseekV32 indexer layers that appear
            only after loading, i.e. the draft model's own layers,
          * builds an indexer metadata builder when such indexer layers exist,
          * falls back to text-only mode if the draft model rejects a probe
            call to ``get_input_embeddings``,
          * shares ``embed_tokens`` and ``lm_head`` with the target model where
            the configuration and weight shapes allow it.
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        # Snapshot the layers registered before loading so the draft's own
        # layers can be isolated by set difference afterwards.
        attn_layers_before = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        indexer_layers_before = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - attn_layers_before
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        self.attn_layer_names = list(draft_attn_layer_names)
        self.indexer_layer_names = list(indexer_layers.keys() - indexer_layers_before)
        self.draft_indexer_metadata_builder = (
            self._make_draft_indexer_metadata_builder(indexer_layers)
        )

        if self.supports_mm_inputs:
            self._disable_mm_inputs_if_unsupported()

        target_language_model = self._resolve_target_language_model(target_model)
        self._maybe_share_embed_tokens(target_language_model)
        self._maybe_share_lm_head(target_language_model)

    def _make_draft_indexer_metadata_builder(self, indexer_layers):
        """Build the metadata builder for the draft's indexer layers, or None."""
        if not self.indexer_layer_names:
            return None
        first_layer = indexer_layers[self.indexer_layer_names[0]]
        builder_cls = first_layer.get_attn_backend().get_builder_cls()
        return builder_cls(
            first_layer.get_kv_cache_spec(),
            self.indexer_layer_names,
            self.vllm_config,
            self.device,
        )

    def _disable_mm_inputs_if_unsupported(self) -> None:
        """Probe the draft's embedding call; switch to text-only on failure."""
        # Even if the target model is multimodal, we can also use
        # text-only draft models
        try:
            probe_ids = torch.tensor([[1]], device=self.input_ids.device)
            self.model.get_input_embeddings(probe_ids, multimodal_embeddings=None)
        except (NotImplementedError, AttributeError, TypeError):
            logger.warning(
                "Draft model does not support multimodal inputs, "
                "falling back to text-only mode"
            )
            self.supports_mm_inputs = False

    def _resolve_target_language_model(self, target_model: nn.Module) -> nn.Module:
        """Return the target's language model.

        For multimodal targets, the target's image token id is also copied
        onto the draft config as ``image_token_index``.
        """
        if not supports_multimodal(target_model):
            return target_model
        # handle multimodality: Qwen2.5-VL stores the id as `image_token_id`.
        if self.get_model_name(target_model) == "Qwen2_5_VLForConditionalGeneration":
            image_token_index = target_model.config.image_token_id
        else:
            image_token_index = target_model.config.image_token_index
        self.model.config.image_token_index = image_token_index
        return target_model.get_language_model()

    def _maybe_share_embed_tokens(self, target_language_model: nn.Module) -> None:
        """Reuse the target's vocab embedding when PP is off and shapes match."""
        if get_pp_group().world_size != 1:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                " from the target model."
            )
            return

        target_inner = target_language_model.model
        if hasattr(target_inner, "embed_tokens"):
            target_embed_tokens = target_inner.embed_tokens
        elif hasattr(target_inner, "embedding"):
            target_embed_tokens = target_inner.embedding
        else:
            raise AttributeError(
                "Target model does not have 'embed_tokens' or 'embedding' attribute"
            )

        eagle_shape = self.model.model.embed_tokens.weight.shape
        if eagle_shape != target_embed_tokens.weight.shape:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                " from the target model."
            )
            return

        logger.info(
            "Assuming the EAGLE head shares the same vocab embedding"
            " with the target model."
        )
        del self.model.model.embed_tokens
        self.model.model.embed_tokens = target_embed_tokens

    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        """Reuse the target's lm_head when the configuration allows it."""
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.vllm_config.speculative_config.method != "eagle3":
            if hasattr(target_language_model, "lm_head"):
                logger.info("Loading EAGLE LM head weights from the target model.")
                self.model.lm_head = target_language_model.lm_head
            return

        # EAGLE-3: only share when both heads exist and their shapes agree.
        can_share = (
            hasattr(self.model, "lm_head")
            and hasattr(target_language_model, "lm_head")
            and self.model.lm_head.weight.shape
            == target_language_model.lm_head.weight.shape
        )
        if not can_share:
            logger.info(
                "The EAGLE head's lm_head will be loaded separately"
                " from the target model."
            )
            return

        logger.info(
            "Assuming the EAGLE head shares the same lm_head"
            " with the target model."
        )
        del self.model.lm_head
        self.model.lm_head = target_language_model.lm_head
1013

1014
1015
1016
1017
1018
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
    ) -> None:
        """Run one draft-model forward pass over ``num_tokens`` buffer slots.

        No attention metadata is passed to the forward context. When
        ``self.supports_mm_inputs`` is set, the pass is fed from
        ``self.inputs_embeds``; otherwise it is fed from ``self.input_ids``.
        The model output is discarded.
        """
        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
            input_ids, inputs_embeds = (
                (None, self.inputs_embeds[:num_tokens])
                if self.supports_mm_inputs
                else (self.input_ids[:num_tokens], None)
            )
            positions = self._get_positions(num_tokens)
            hidden_states = self.hidden_states[:num_tokens]
            self.model(
                input_ids=input_ids,
                positions=positions,
                hidden_states=hidden_states,
                inputs_embeds=inputs_embeds,
            )
1033

1034
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        Only the first EAGLE attention layer (``self.attn_layer_names[0]``) is
        looked up. ``self.runner.attn_groups`` is walked KV-cache group by
        KV-cache group. Within each KV-cache group, only the first attention
        group that contains the layer is consulted. If that group yields no
        builder, the search continues with the next KV-cache group.

        Returns:
            The metadata builder for EAGLE layers (a single builder, not a
            list).

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
                Raised explicitly, so the check also holds under ``python -O``.
        """
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    if builder is not None:
                        return builder
                    # Only the first matching attention group per KV-cache
                    # group is consulted.
                    break

        raise AssertionError(
            "Failed to find attention metadata builder for EAGLE layers."
        )

1059
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.

        Need this assumption to ensure all eagle layers can use the same
        AttentionMetadata. May extend to multiple AttentionMetadata in the
        future.

        Raises:
            KeyError: If an eagle layer is not listed in any KV cache group.
            AssertionError: If the eagle layers span more than one KV cache
                group, or if there are no eagle layers at all. Raised
                explicitly, so the check also holds under ``python -O``.
        """
        # Map each layer to the index of its KV cache group. A layer listed in
        # several groups keeps the last index, as with sequential assignment.
        group_idx_of_layer: dict[str, int] = {
            layer_name: group_idx
            for group_idx, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        eagle_group_indices = {
            group_idx_of_layer[layer_name] for layer_name in self.attn_layer_names
        }
        if len(eagle_group_indices) != 1:
            raise AssertionError(
                "All eagle layers should belong to the same kv cache group"
            )
1081

1082

1083
1084
1085
1086
# NOTE(woosuk): The code below is currently unused; we always use argmax to
# sample the draft tokens. We will switch to it once we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of ``logits``.

    Args:
        logits: Draft logits whose rows line up with the entries of
            ``sampling_metadata.temperature``. Unless every request is greedy,
            this tensor is divided by the temperature IN PLACE.
        sampling_metadata: Supplies ``all_greedy``, ``all_random`` and the
            per-request ``temperature`` tensor. Non-positive temperatures
            (including the -1 sentinel) are treated as greedy.

    Returns:
        ``(next_token_ids, probs)``. If all requests are greedy, ``probs`` is
        ``logits`` itself (not normalized), because draft probs are not used
        in rejection sampling for greedy requests. Otherwise ``probs`` is the
        float32 softmax of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        return logits.argmax(dim=-1), logits

    # Treat every non-positive temperature as greedy. Besides the -1 sentinel,
    # this also covers 0, which would otherwise divide the logits by zero and
    # turn the softmax into NaNs.
    is_greedy = sampling_metadata.temperature <= 0
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    # argmax(probs / q) with i.i.d. q ~ Exp(1) is an exact sample from the
    # categorical distribution `probs` (the "exponential race" form of the
    # Gumbel-max trick).
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)

    if not sampling_metadata.all_random:
        # Greedy rows take the deterministic argmax instead of the sample.
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs