# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn

from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mla.common import (MLACommonDecodeMetadata,
                                                   MLACommonMetadata)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel

logger = init_logger(__name__)

# Slot id stored in a slot mapping for tokens whose KV entries must not be
# written to the KV cache (it is used below to mask out draft positions that
# run past max_model_len).
PADDING_SLOT_ID: int = -1


class EagleProposer:
    """Propose draft tokens with an EAGLE-style draft head.

    Supported ``speculative_config.method`` values are ``"eagle"``,
    ``"eagle3"`` and ``"deepseek_mtp"``. :meth:`load_model` attaches the
    draft model. :meth:`propose` then runs one forward pass over the accepted
    target tokens, followed by ``num_speculative_tokens - 1`` single-token
    steps. Each draft token is chosen with ``argmax``.

    Model inputs are staged in persistent device buffers (``input_ids``,
    ``positions``, ``hidden_states``) so that forward passes can run under
    CUDA graphs. In full-cudagraph mode, the attention metadata remembered by
    :meth:`dummy_run` is updated in place before every forward pass whose
    size fits a captured graph.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Read speculative-decoding settings and allocate input buffers.

        Args:
            vllm_config: Engine configuration. The speculative, model, cache,
                scheduler and compilation sub-configs are read here.
            device: Device on which the persistent buffers are allocated.
            runner: Model runner whose ``attn_metadata_builders`` build the
                attention metadata for ``deepseek_mtp``. It may be None for
                the other methods.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner

        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE and
                               not self.vllm_config.model_config.enforce_eager)
        self.use_full_cuda_graph = (
            self.use_cuda_graph
            and vllm_config.compilation_config.full_cuda_graph)
        # The last entry is treated as the largest captured size. This assumes
        # cudagraph_capture_sizes is sorted in descending order -- TODO confirm.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=device)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)

        # attention metadata captured in full cudagraph mode (set by
        # dummy_run the first time it receives attention metadata)
        self.attn_metadata_cudagraph = None
        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
                                   1,
                                   device=device,
                                   dtype=torch.int32)

    def _fits_cudagraph(self, num_tokens: int) -> bool:
        """Return True if ``num_tokens`` is within the largest captured size.

        An empty capture-size list yields False instead of an IndexError.
        """
        return (bool(self.cudagraph_batch_sizes)
                and num_tokens <= self.cudagraph_batch_sizes[-1])

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: list[int],
        sampling_metadata: SamplingMetadata
    ) -> torch.Tensor:
        """Generate ``num_speculative_tokens`` draft tokens per request.

        Args:
            target_token_ids: Accepted target tokens, concatenated over
                requests.
            target_positions: Positions of ``target_token_ids``.
            target_hidden_states: Target hidden states for those tokens. For
                eagle3 they are first reduced to ``hidden_size`` by
                ``combine_hidden_states``.
            target_slot_mapping: KV-cache slots of the tokens.
            next_token_ids: Token sampled by the target model for each
                request. It replaces the request's last shifted input id.
            cu_num_tokens: Cumulative per-request token counts.
            block_table: Per-request KV-cache block table.
            num_rejected_tokens: Forwarded to ``CommonAttentionMetadata`` for
                deepseek_mtp only.
            sampling_metadata: Unused; draft tokens are chosen with argmax.

        Returns:
            ``[batch_size, num_speculative_tokens]`` draft token ids.

        Raises:
            ValueError: If ``self.method`` is not supported.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # FA requires seq_len to have dtype int32.
        seq_lens = (target_positions[last_token_indices] + 1).int()

        if self.method in ["eagle", "eagle3"]:
            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
            max_seq_len = seq_lens.max().item()
            max_query_len = (cu_num_tokens[1:] -
                             cu_num_tokens[:-1]).max().item()
            attn_metadata = FlashAttentionMetadata(
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                query_start_loc=cu_num_tokens,
                max_seq_len=max_seq_len,
                seq_lens=seq_lens,
                block_table=block_table,
                slot_mapping=target_slot_mapping,
                # TODO(woosuk): Support cascade attention.
                use_cascade=False,
                common_prefix_len=0,
                cu_prefix_query_lens=None,
                prefix_kv_lens=None,
                suffix_kv_lens=None,
            )
        elif self.method == "deepseek_mtp":
            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
            max_query_len = query_lens.max().item()

            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=cu_num_tokens,
                seq_lens=seq_lens,
                num_reqs=batch_size,
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                num_rejected_tokens=num_rejected_tokens,
            )

            assert self.runner is not None

            # FIXME: need to consider multiple kv_cache_groups
            attn_metadata = self.runner.attn_metadata_builders[0].build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise ValueError(f"Unsupported method: {self.method}")

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {
            layer_name: attn_metadata
            for layer_name in self.attn_layer_names
        }

        if self.use_cuda_graph and self._fits_cudagraph(num_tokens):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.use_full_cuda_graph and self._fits_cudagraph(num_tokens):
            cg_metadata = self.attn_metadata_cudagraph
            assert cg_metadata is not None
            cg_metadata.slot_mapping[:num_tokens] = attn_metadata.slot_mapping
            # Padding tokens must not write into the KV cache. Clear any
            # stale slot ids left in the padded tail by a previous pass.
            cg_metadata.slot_mapping[num_tokens:num_input_tokens].fill_(
                PADDING_SLOT_ID)
            cg_metadata.query_start_loc[:batch_size + 1] = (
                attn_metadata.query_start_loc)
            if self.method in ["eagle", "eagle3"]:
                cg_metadata.seq_lens[:batch_size] = attn_metadata.seq_lens
                cg_metadata.block_table[:batch_size] = (
                    attn_metadata.block_table)
            elif self.method == "deepseek_mtp":
                cg_metadata.num_actual_tokens = (
                    attn_metadata.num_actual_tokens)
                cg_metadata.num_decodes = attn_metadata.num_decodes
                cg_metadata.num_decode_tokens = (
                    attn_metadata.num_decode_tokens)
                cg_metadata.num_prefills = attn_metadata.num_prefills
                if attn_metadata.decode is not None:
                    # Decode block tables / seq_lens have one row per decode
                    # *request*, so slice by num_decodes, not by token count.
                    num_decodes = attn_metadata.num_decodes
                    cg_metadata.decode.block_table[:num_decodes] = (
                        attn_metadata.decode.block_table)
                    cg_metadata.decode.seq_lens[:num_decodes] = (
                        attn_metadata.decode.seq_lens)

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )
            if self.method == "deepseek_mtp":
                # The MTP model returns a single hidden-state tensor.
                last_hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = torch.argmax(logits, dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        # Advanced indexing returns a copy, so the in-place `positions += 1`
        # below does not modify target_positions.
        positions = target_positions[last_token_indices]
        if self.method == "deepseek_mtp":
            hidden_states = last_hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if self.use_cuda_graph and self._fits_cudagraph(batch_size):
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size

        # From here on every request contributes exactly one query token.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]

        if isinstance(attn_metadata, MLACommonMetadata):
            attn_metadata.num_decodes = batch_size
            attn_metadata.num_decode_tokens = batch_size
            attn_metadata.num_prefills = 0
            builder = self.runner.attn_metadata_builders[0]
            block_table = builder.block_table.get_device_tensor(
            )[:batch_size, ...]
            attn_metadata.decode = builder._build_decode(
                block_table_tensor=block_table,
                seq_lens=seq_lens,
            )

        for i in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            if isinstance(attn_metadata, MLACommonMetadata):
                attn_metadata.decode.seq_lens += 1
                # Same as the non-MLA path: requests past max_model_len get
                # seq_len 1 so attention never indexes past their block
                # table.
                attn_metadata.decode.seq_lens.masked_fill_(
                    exceeds_max_model_len, 1)
            else:
                # Increment the sequence lengths.
                attn_metadata.max_seq_len += 1
                attn_metadata.seq_lens += 1
                # Consider max model length.
                attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                                self.max_model_len)
                # For the requests that exceed the max model length, we set the
                # sequence length to 1 to minimize their overheads in attention.
                attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = block_table.gather(
                dim=1, index=block_numbers.view(-1, 1)).view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states

            if self.use_full_cuda_graph and self._fits_cudagraph(batch_size):
                cg_metadata = self.attn_metadata_cudagraph
                assert cg_metadata is not None
                cg_metadata.slot_mapping[:batch_size] = (
                    attn_metadata.slot_mapping)
                # Keep padded draft tokens from writing into the KV cache.
                cg_metadata.slot_mapping[batch_size:input_batch_size].fill_(
                    PADDING_SLOT_ID)
                if i == 0:
                    # query_start_loc does not change across draft steps.
                    cg_metadata.query_start_loc[:batch_size + 1] = (
                        attn_metadata.query_start_loc)
                if self.method in ["eagle", "eagle3"]:
                    cg_metadata.seq_lens[:batch_size] = attn_metadata.seq_lens
                    if i == 0:
                        cg_metadata.block_table[:batch_size] = (
                            attn_metadata.block_table)
                elif self.method == "deepseek_mtp":
                    cg_metadata.num_actual_tokens = (
                        attn_metadata.num_actual_tokens)
                    cg_metadata.num_decodes = attn_metadata.num_decodes
                    cg_metadata.num_decode_tokens = (
                        attn_metadata.num_decode_tokens)
                    cg_metadata.num_prefills = attn_metadata.num_prefills
                    # num_decodes == num_decode_tokens == batch_size here.
                    cg_metadata.decode.seq_lens[:batch_size] = (
                        attn_metadata.decode.seq_lens)
                    if i == 0:
                        cg_metadata.decode.block_table[:batch_size] = (
                            attn_metadata.decode.block_table)

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    self.input_ids[:input_batch_size],
                    self.positions[:input_batch_size],
                    self.hidden_states[:input_batch_size],
                )
                if self.method == "deepseek_mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = last_hidden_states[:batch_size]
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
                    hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        num_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Drop rejected tokens and return the compacted layout.

        Args:
            cu_target_query_lens: Cumulative target query lengths.
            num_rejected_tokens: Rejected-token count per request.
            num_tokens: Total number of kept tokens; size of the returned
                ``token_indices``.

        Returns:
            ``(cu_num_tokens, token_indices)``. ``cu_num_tokens`` holds the
            cumulative kept-token counts. ``token_indices`` (int32) is filled
            by ``prepare_eagle_input_kernel`` with the indices of the kept
            tokens in the original layout.
        """
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens

        # [a - n1, b - n2, c - n3] ->
        # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        cu_num_tokens = torch.zeros_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        token_indices = torch.empty(
            num_tokens,
            dtype=torch.int32,
            device=cu_target_query_lens.device,
        )
        batch_size = num_rejected_tokens.shape[0]
        BLOCK_SIZE = 1024
        prepare_eagle_input_kernel[(batch_size, )](
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return cu_num_tokens, token_indices

    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and share weights with ``target_model``.

        The draft model's attention layers are found by comparing the
        config's ``Attention`` layers before and after the draft model is
        built.

        The vocab embedding is shared when all of these hold:
        pipeline-parallel world size is 1, the method is not deepseek_mtp,
        and the embedding shapes match.

        The LM head is shared for every method except eagle3, provided the
        target language model has an ``lm_head`` attribute.
        """
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        from vllm.compilation.backends import set_model_tag
        with set_model_tag("eagle_head"):
            self.model = get_model(vllm_config=self.vllm_config,
                                   model_config=draft_model_config)

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)
        self.attn_layer_names = list(draft_attn_layer_names)

        if supports_multimodal(target_model):
            # handle multimodality
            self.model.config.image_token_index = (
                target_model.config.image_token_index)
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        # (the checks short-circuit in this order on purpose: the shape
        # comparison is only reached for non-deepseek_mtp methods)
        if (get_pp_group().world_size == 1
                and self.method != "deepseek_mtp"
                and self.model.model.embed_tokens.weight.shape
                == target_language_model.model.embed_tokens.weight.shape):
            logger.info(
                "Assuming the EAGLE head shares the same vocab embedding"
                " with the target model.")
            del self.model.model.embed_tokens
            self.model.model.embed_tokens = (
                target_language_model.model.embed_tokens)
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                " from the target model.")

        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.vllm_config.speculative_config.method != "eagle3" and \
                hasattr(target_language_model, "lm_head"):
            logger.info("Loading EAGLE LM head weights from the target model.")
            self.model.lm_head = target_language_model.lm_head

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        attn_metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Run the draft model on the first ``num_tokens`` buffer entries.

        The first non-None ``attn_metadata`` is remembered, via the entry of
        the first draft attention layer, as the metadata that
        full-cudagraph passes in :meth:`propose` update in place.
        """
        if attn_metadata is not None and self.attn_metadata_cudagraph is None:
            self.attn_metadata_cudagraph = attn_metadata[
                self.attn_layer_names[0]]
        with set_forward_context(attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_tokens):
            self.model(
                self.input_ids[:num_tokens],
                self.positions[:num_tokens],
                self.hidden_states[:num_tokens],
            )

    def validate_same_kv_cache_group(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.
        """
        # Map each layer name to the index of its KV-cache group; a layer
        # listed in several groups keeps the last one.
        layer_to_group: dict[str, int] = {
            layer_name: group_id
            for group_id, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        group_ids = {
            layer_to_group[layer_name]
            for layer_name in self.attn_layer_names
        }
        assert len(group_ids) == 1, \
            "All eagle layers should belong to the same kv cache group"


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row and return it with the draft probs.

    Returns ``(next_token_ids, probs)``.

    * If every request is greedy, ``probs`` is ``logits`` itself; rejection
      sampling does not read the draft probs in that case.
    * Otherwise ``logits`` is divided *in place* by the per-request
      temperature. Greedy rows, marked by temperature ``-1``, divide by 1.0.
      ``probs`` is the float32 softmax of the result.
    * Random rows take ``argmax(probs / q)`` with ``q`` drawn i.i.d. from
      Exp(1).
    * Greedy rows are overridden with ``argmax(probs)`` unless
      ``all_random`` is set.
    """
    if sampling_metadata.all_greedy:
        # Greedy-only batch: the draft probs are never consulted, so the
        # raw logits stand in for them.
        return logits.argmax(dim=-1), logits

    greedy_mask = sampling_metadata.temperature == -1
    safe_temperature = torch.where(greedy_mask, 1.0,
                                   sampling_metadata.temperature)
    logits.div_(safe_temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    exp_noise = torch.empty_like(probs).exponential_()
    # Out-of-place division on purpose: `probs` is returned and later used
    # for rejection sampling, so it must not be overwritten here.
    sampled_ids = probs.div(exp_noise).argmax(dim=-1).view(-1)

    if sampling_metadata.all_random:
        return sampled_ids, probs
    return torch.where(greedy_mask, probs.argmax(dim=-1), sampled_ids), probs