eagle.py 49.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6
from typing import Optional
7

8
import numpy as np
9
10
11
import torch
import torch.nn as nn

12
13
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
14
                         get_layers_from_vllm_config)
15
from vllm.distributed.parallel_state import get_pp_group
16
from vllm.forward_context import set_forward_context
17
from vllm.logger import init_logger
18
from vllm.model_executor.model_loader import get_model
19
from vllm.model_executor.models import supports_multimodal
20
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
21
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
22
from vllm.multimodal import MULTIMODAL_REGISTRY
23
from vllm.platforms import current_platform
24
from vllm.utils import is_pin_memory_available
25
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
26
27
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                  TreeAttentionMetadataBuilder)
28
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
29
30
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
31
from vllm.v1.kv_cache_interface import KVCacheConfig
32
from vllm.v1.sample.metadata import SamplingMetadata
33
34
35
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
36
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
37

38
39
# Module-level logger for the speculative-decoding proposer.
logger = init_logger(__name__)

# Sentinel written into the slot mapping for draft positions that run past
# the max model length, so the KV cache is not updated for those entries.
PADDING_SLOT_ID = -1

42
43
44
45
46
47
48

class EagleProposer:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Set up configuration, persistent buffers and tree metadata.

        Args:
            vllm_config: Engine configuration. Its ``speculative_config``
                must not be ``None``.
            device: Device on which the persistent input buffers are
                allocated.
            runner: Optional model runner. It is only stored here; the
                drafting methods read ``runner.attn_groups`` from it.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # Host-side arange, sliced later to build query_start_loc_cpu.
        self.token_arange_np = np.arange(self.max_num_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config)

        # Populated outside of __init__; empty/None until then.
        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
        self.draft_indexer_metadata_builder: Optional[
            AttentionMetadataBuilder] = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        # CUDA graphs are used only with piecewise compilation, off XPU, and
        # when neither the target nor the speculative config enforces eager.
        self.use_cuda_graph = (not current_platform.is_xpu()
                               and self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE and
                               not self.vllm_config.model_config.enforce_eager
                               and not self.speculative_config.enforce_eager)
        # Stored in reverse order of the configured capture sizes; callers
        # compare against the last element.
        self.cudagraph_batch_sizes = list(
            reversed(self.vllm_config.compilation_config.
                     cudagraph_capture_sizes)) if self.use_cuda_graph else []

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=device)
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        if self.uses_mrope:
            # M-RoPE need (3, max_num_tokens)
            self.mrope_positions = torch.zeros((3, self.max_num_tokens),
                                               dtype=torch.int64,
                                               device=device)
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device=device)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(max_num_slots_for_arange,
                                   device=device,
                                   dtype=torch.int32)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)

        # Staging buffer for fallback next-token ids computed on the host.
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True)

        # Determine allowed attention backends once during initialization.
        # None means "no restriction"; only ROCm restricts the set here.
        self.allowed_attn_types: Optional[tuple] = None
        if current_platform.is_rocm():
            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
            # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
            if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata)
                rocm_types.append(AiterFlashAttentionMetadata)
            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int,
                                      ...]] = ast.literal_eval(spec_token_tree)
        # The depth is taken from the last entry, i.e. the last tree choice
        # is expected to be (one of) the deepest.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        # cu_drafts_per_level[i]: number of drafts in levels 0..i.
        # child_drafts_per_level[i]: drafts at level i per draft at level
        # i - 1 (integer division); level 0 uses its own draft count.
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
                                            num_drafts_per_level[level])
            self.child_drafts_per_level.append(num_drafts_per_level[level] //
                                               num_drafts_per_level[level - 1])
        # Precompute draft position offsets in flattened tree.
        # Shape: (max_batch_size, len(tree_choices)), each row is 1..N.
        self.tree_draft_pos_offsets = torch.arange(
            1,
            len(self.tree_choices) + 1,
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)

167
168
169
170
171
172
173
174
175
176
177
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
            self.positions[:num_tokens] = positions

178
179
180
181
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: Optional[tuple[list[torch.Tensor],
                                        torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Generate draft token ids for every request in the batch.

        The draft model is first run over the target tokens shifted left by
        one, with each request's last slot replaced by its next token. Unless
        only one speculative token is requested or tree attention is used,
        ``num_speculative_tokens - 1`` further single-token steps follow,
        each picking the argmax token.

        Args:
            target_token_ids: Token ids processed by the target model.
            target_positions: Positions matching ``target_token_ids``.
            target_hidden_states: Target-model hidden states per token.
            next_token_ids: One next token id per request.
            last_token_indices: Index of each request's last token in the
                flattened batch. When ``None`` it is derived from
                ``common_attn_metadata.query_start_loc``.
            common_attn_metadata: Attention metadata for the first pass.
                It is modified in place by the multi-step loop.
            sampling_metadata: Not read by this method.
            mm_embed_inputs: Optional ``(mm_embeds, is_mm_embed)`` pair
                forwarded to ``get_input_embeddings`` on the first pass.

        Returns:
            Draft token ids: ``[batch_size, 1]`` when a single speculative
            token is requested, ``[batch_size, num_tree_tokens]`` for tree
            attention, otherwise ``[batch_size, num_speculative_tokens]``.

        Raises:
            ValueError: If more than one speculative token is requested and
                the attention metadata type is not in
                ``self.allowed_attn_types`` (when that restriction is set).
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        # FIXME: need to consider multiple kv_cache_groups
        ubatch_id = dbo_current_ubatch_id()
        attn_metadata_builder = \
            self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0)
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                ))
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        if self.use_cuda_graph and \
                num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            if self.method == "mtp":
                # The model returns a single tensor for this method.
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        # Indexing with a tensor yields a copy, so the in-place increments
        # below do not touch target_positions.
        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]
        # NOTE(review): the forward pass above special-cases
        # method == "mtp" while this branch checks the model-specific
        # names; confirm which method strings can reach this point.
        if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and \
            not isinstance(attn_metadata, self.allowed_attn_types):
            raise ValueError(
                "Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}")

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        if self.use_cuda_graph and \
                batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size

        # From here on every request contributes exactly one query token.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[:batch_size + 1]).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1
            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length.
            # Since it is complex to remove such requests from the batch,
            # we keep them in the batch but adjust the position ids
            # and slot mappings to avoid the
            # out-of-range access during the model execution.
            # The draft tokens generated with this adjustment
            # should be ignored.
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            if self.uses_mrope:
                exceeds_max_model_len = positions[0] >= self.max_model_len
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions), positions)
            else:
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0,
                                                positions)

            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            common_attn_metadata.seq_lens_cpu += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
                                                       1)

            common_attn_metadata.num_computed_tokens_cpu = \
                common_attn_metadata.seq_lens_cpu - 1

            # Compute the slot mapping.
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // self.block_size
            else:
                block_numbers = clamped_positions // self.block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * self.block_size +
                    clamped_positions[0] % self.block_size)
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * self.block_size +
                    clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID)

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata,
                draft_index=token_index + 1)
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = \
                    self.model.get_input_embeddings(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(input_batch_size),
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            # Drop any rows added by cudagraph padding.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
435

436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
    def prepare_next_token_ids_cpu(
            self, sampled_token_ids: list[list[int]],
            requests: dict[str,
                           CachedRequestState], gpu_input_batch: InputBatch,
            num_scheduled_tokens: dict[str, int]) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           num_scheduled_tokens[req_id])
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        next_token_ids = torch.tensor(next_token_ids,
                                      dtype=torch.int32,
                                      device=self.input_ids.device)
        return next_token_ids

    def prepare_next_token_ids_padded(self,
                               common_attn_metadata: CommonAttentionMetadata,
                               sampled_token_ids: torch.Tensor,
                               requests: dict[str, CachedRequestState],
                               gpu_input_batch: InputBatch,
                               discard_request_indices: torch.Tensor,
                               num_discarded_requests: int) -> \
                                tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        self.backup_next_token_ids.np[:num_reqs] = np.array([
            requests[gpu_input_batch.req_ids[i]].get_token_id(
                common_attn_metadata.seq_lens_cpu[i].item())
            for i in range(num_reqs)
        ])
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled tokens indices that should not be sampled.
        discard_sampled_tokens_req_indices = \
            discard_request_indices[:num_discarded_requests]

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1)

        # Generate a mask for all valid tokens within those requests
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
                                         dtype=torch.bool)
        else:
            valid_mask = (
                (valid_sampled_token_ids_gpu != -1) &
                (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1,
            last_valid_indices_safe.unsqueeze(1)).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1, selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size])

        return next_token_ids, valid_sampled_tokens_count

    def prepare_inputs_padded(self,
                                common_attn_metadata: CommonAttentionMetadata,
                                spec_decode_metadata: SpecDecodeMetadata,
                                valid_sampled_tokens_count: torch.Tensor) -> \
                    tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        Build the speculator inputs without removing rejected tokens.

        Every token of the batch is kept as input to the speculator; the
        rejected tokens act as padding and are skipped later through the
        returned `token_indices_to_sample`. The returned metadata reuses the
        tensors of `common_attn_metadata` and only re-derives the token
        count, the maximum query/sequence lengths and the slot mapping.
        No blocking CPU operations should be introduced in this function.

        Returns `(spec_common_attn_metadata, token_indices,
        token_indices_to_sample)`.
        """
        cu_drafts = spec_decode_metadata.cu_num_draft_tokens
        # Undo the cumulative sum to get the draft count of each request.
        drafts_per_req = torch.cat(
            [cu_drafts[0:1], cu_drafts[1:] - cu_drafts[:-1]])

        # Requests without draft tokens have nothing that can be rejected.
        rejected_per_req = torch.where(
            drafts_per_req > 0,
            drafts_per_req + 1 - valid_sampled_tokens_count,
            torch.zeros_like(drafts_per_req))

        cpu_query_starts = common_attn_metadata.query_start_loc_cpu
        query_lens = cpu_query_starts[1:] - cpu_query_starts[:-1]

        total_num_tokens = cpu_query_starts[-1].item()
        token_indices = self.arange[:total_num_tokens]

        longest_query = query_lens.max().item()
        longest_seq = common_attn_metadata.seq_lens_cpu.max().item()
        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=cpu_query_starts,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=longest_query,
            max_seq_len=longest_seq,
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
        )

        # Last token of each request, moved back past the rejected tokens.
        last_token_per_req = common_attn_metadata.query_start_loc[1:] - 1
        token_indices_to_sample = last_token_per_req - rejected_per_req

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

588
589
590
591
592
593
594
595
596
597
598
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
599
        tree_attn_metadata_builder = \
600
            self.runner.attn_groups[0][0].get_metadata_builder()
601
602
603
        assert isinstance(tree_attn_metadata_builder,
                          TreeAttentionMetadataBuilder)

604
        total_num_drafts = self.cu_drafts_per_level[0]
605
606
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
607
        num_children = self.child_drafts_per_level[0]
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children,
                                         dim=-1).indices.view(batch_size, -1)
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(0,
                                     device=self.input_ids.device,
                                     dtype=self.input_ids.dtype)
        tree_positions = torch.empty(0,
                                     device=self.positions.device,
                                     dtype=self.positions.dtype)
        tree_hidden_states = torch.empty(0,
                                         device=self.hidden_states.device,
                                         dtype=self.hidden_states.dtype)
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) +
            self.tree_draft_pos_offsets[:batch_size, :])
        tree_depth = len(self.cu_drafts_per_level)
631
        for level in range(tree_depth - 1):
632
633
634
635
636
637
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions +
                                     total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
638
            draft_positions = torch.where(
639
640
641
                exceeds_max_model_len,
                0,
                draft_positions,
642
643
            ).view(batch_size, -1)

644
645
            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
646
647
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1)
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1)

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids],
                                       dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions],
                                       dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1)

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
664
            query_len = total_num_drafts
665
666
667
668
669
670
671
672
673
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[:batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
674
                draft_index=level + 1,
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
            )

            # Apply new attention metadata to all layers.
            per_layer_attn_metadata = {}
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # Consider max model length.
            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                            self.max_model_len)
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            query_positions = flattened_draft_positions[:, level:level +
                                                        query_len]
            block_numbers = query_positions // self.block_size
            block_ids = attn_metadata.block_table.gather(dim=1,
                                                         index=block_numbers)
            slot_mapping = (block_ids * self.block_size +
                            query_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(
                num_tokens, -1)

            if self.use_cuda_graph and \
712
                    num_tokens <= self.cudagraph_batch_sizes[-1]:
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
                num_input_tokens = self.vllm_config.pad_for_cudagraph(
                    num_tokens)
            else:
                num_input_tokens = num_tokens
            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=num_input_tokens):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1)[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1)[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts,
737
                                                 -1))
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children,
                                             dim=-1).indices.view(
                                                 batch_size, -1)
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level +
                                                        1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

755
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        Prepare the inputs for speculative decoding.

        Builds a new CommonAttentionMetadata that accounts for the rejected
        draft tokens, and computes the indices of the tokens that should be
        fed to the speculator.

        Args:
            common_attn_metadata: Attention metadata of the current batch.
            sampled_token_ids: Per-request lists of sampled token ids. Only
                the length of each list is used here.
            num_draft_tokens: Per-request number of draft tokens. Requests
                with 0 draft tokens are treated as having no rejections.

        Returns:
            A tuple `(spec_common_attn_metadata, token_indices)` where
            `token_indices` (on the same device as
            `common_attn_metadata.query_start_loc`) indexes the kept tokens
            in the original flattened token layout.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # A request with n drafts yields up to n + 1 sampled tokens; every
        # missing one counts as a rejected draft.
        num_rejected_tokens = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(num_rejected_tokens,
                                           dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
            - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        query_len_per_req = (query_start_loc_cpu[1:] -
                             query_start_loc_cpu[:-1])
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available())
        # The numpy array is a view of the tensor, so the cumsum below
        # writes straight into `new_query_start_loc_cpu`.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = new_query_start_loc_np[-1]
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                                  new_num_tokens_per_req_np)
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = self.token_arange_np[:total_num_tokens] \
            - new_query_start_locs_expanded

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(
            device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device,
                                                       non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE(review): this is the max of the pre-rejection query
            # lengths, not of `new_num_tokens_per_req` -- confirm that an
            # upper bound is what consumers expect.
            max_query_len=query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
        )

        return spec_common_attn_metadata, token_indices
857

858
859
860
861
862
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of `model`.

        If `model` exposes a `module` attribute (multi-GPU wrapper), the
        class name of that inner module is returned instead.
        """
        unwrapped = getattr(model, 'module', model)
        return unwrapped.__class__.__name__

863
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and connect it to the target model.

        Steps, in order:
          1. Load the draft model and record which attention / indexer
             layers appear in the vLLM config afterwards but not before
             (these are stored as the draft model's layer names).
          2. Create a metadata builder for the draft indexer layers, if any.
          3. If multimodal inputs are enabled, probe whether the draft model
             accepts them and fall back to text-only mode otherwise.
          4. Share the target model's `embed_tokens` (only when the pipeline
             parallel world size is 1 and the shapes match) and `lm_head`
             with the draft model where applicable.

        Args:
            target_model: The target model whose embedding / lm_head may be
                shared with the draft model.

        Raises:
            AttributeError: If the target language model has neither an
                `embed_tokens` nor an `embedding` attribute (checked only
                when the pipeline parallel world size is 1).
        """
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        # Layer names present before the draft model is loaded. Anything
        # that shows up afterwards is attributed to the draft model.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config,
                                        DeepseekV32IndexerCache).keys())

        from vllm.compilation.backends import set_model_tag
        with set_model_tag("eagle_head"):
            self.model = get_model(vllm_config=self.vllm_config,
                                   model_config=draft_model_config)

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)
        indexer_layers = get_layers_from_vllm_config(self.vllm_config,
                                                     DeepseekV32IndexerCache)
        draft_indexer_layer_names = (indexer_layers.keys() -
                                     target_indexer_layer_names)
        self.attn_layer_names = list(draft_attn_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if self.indexer_layer_names:
            # The builder is derived from the first draft indexer layer and
            # is given the names of all draft indexer layers.
            first_indexer = indexer_layers[self.indexer_layer_names[0]]
            builder_cls = first_indexer.get_attn_backend().get_builder_cls()
            self.draft_indexer_metadata_builder = builder_cls(
                first_indexer.get_kv_cache_spec(),
                self.indexer_layer_names,
                self.vllm_config,
                self.device,
            )
        else:
            self.draft_indexer_metadata_builder = None

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                dummy_input_ids = torch.tensor([[1]],
                                               device=self.input_ids.device)
                self.model.get_input_embeddings(dummy_input_ids,
                                                multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode")
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            if (self.get_model_name(target_model) ==
                    "Qwen2_5_VLForConditionalGeneration"):
                # This target config names the field `image_token_id`.
                self.model.config.image_token_index = (
                    target_model.config.image_token_id)
            else:
                self.model.config.image_token_index = (
                    target_model.config.image_token_index)
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model
        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, 'embed_tokens'):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, 'embedding'):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' "
                    "attribute")

            # Check if shapes match and we found the embedding
            eagle_shape = self.model.model.embed_tokens.weight.shape
            target_shape = target_embed_tokens.weight.shape
            if eagle_shape == target_shape:
                logger.info(
                    "Assuming the EAGLE head shares the same vocab embedding"
                    " with the target model.")
                del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
            else:
                logger.info(
                    "The EAGLE head's vocab embedding will be loaded separately"
                    " from the target model.")
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                " from the target model.")

        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.vllm_config.speculative_config.method != "eagle3":
            # Non-eagle3 methods take the target lm_head whenever the
            # target has one; no shape check is performed here.
            if hasattr(target_language_model, "lm_head"):
                logger.info(
                    "Loading EAGLE LM head weights from the target model.")
                self.model.lm_head = target_language_model.lm_head
        else:
            # eagle3 shares the lm_head only when both models define one
            # and the weight shapes agree.
            if (hasattr(self.model, "lm_head")
                    and hasattr(target_language_model, "lm_head")
                    and self.model.lm_head.weight.shape
                    == target_language_model.lm_head.weight.shape):
                logger.info("Assuming the EAGLE head shares the same lm_head"
                            " with the target model.")
                del self.model.lm_head
                self.model.lm_head = target_language_model.lm_head
            else:
                logger.info(
                    "The EAGLE head's lm_head will be loaded separately"
                    " from the target model.")
977

978
979
980
981
982
983
984
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
    ) -> None:
        """Run one forward pass of the draft model and discard the output.

        The pass uses the first `num_tokens` entries of this proposer's
        input buffers and a forward context without attention metadata.

        Args:
            num_tokens: Number of tokens to feed through the model.
        """
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            # Multimodal-capable drafts are fed embeddings, text-only
            # drafts are fed token ids; exactly one of the two is set.
            if self.supports_mm_inputs:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None

            self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_tokens),
                hidden_states=self.hidden_states[:num_tokens],
                inputs_embeds=inputs_embeds,
            )
998

999
1000
1001
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        The first name in `self.attn_layer_names` is looked up in the
        runner's attention groups, and the metadata builder of the first
        group containing it is returned.

        Returns:
            The metadata builder for EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
            IndexError: If `self.attn_layer_names` is empty.
        """
        chosen_layer = self.attn_layer_names[0]

        # Lazily scan the groups so the builder is requested only from the
        # first group that contains the chosen layer.
        builder = next(
            (attn_group.get_metadata_builder()
             for kv_cache_group in self.runner.attn_groups
             for attn_group in kv_cache_group
             if chosen_layer in attn_group.layer_names),
            None,
        )

        assert builder is not None, (
            "Failed to find attention metadata builder for EAGLE layers.")
        return builder

1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
    def validate_same_kv_cache_group(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Check that every eagle layer lives in one and the same KVCacheGroup.

        This assumption lets all eagle layers use the same
        AttentionMetadata. It may be relaxed to multiple AttentionMetadata
        in the future.

        Raises:
            AssertionError: If the eagle layers do not map to exactly one
                kv cache group.
        """
        # Map each layer name to the index of the group that lists it.
        layer_to_group: dict[str, int] = {
            layer_name: group_idx
            for group_idx, group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in group.layer_names
        }
        eagle_group_ids = {
            layer_to_group[layer_name]
            for layer_name in self.attn_layer_names
        }
        assert len(eagle_group_ids) == 1, (
            "All eagle layers should belong to the same kv cache group")

1043

1044
1045
1046
1047
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one token per row of `logits` and return the draft probs.

    Only the temperature (and the all_greedy / all_random flags) of
    `sampling_metadata` is consulted.

    NOTE: Unless all requests are greedy, `logits` is divided by the
    temperature in place, so the caller's tensor is modified.

    Args:
        logits: Logits with the vocabulary on the last dimension.
        sampling_metadata: Sampling settings of the batch.

    Returns:
        A tuple `(next_token_ids, probs)`. When all requests are greedy,
        `probs` is the unmodified `logits` tensor itself; otherwise it is
        the float32 softmax of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    # A temperature of -1 marks a greedy request; use 1.0 for those rows so
    # the division leaves their logits unchanged.
    is_greedy = sampling_metadata.temperature == -1
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        # Mixed batch: greedy rows take the argmax instead of the sample.
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs