eagle.py 25 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, Optional
4
import numpy as np
5

6
7
8
import torch
import torch.nn as nn

9
10
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
11
                         get_layers_from_vllm_config)
12
from vllm.distributed.parallel_state import get_pp_group
13
from vllm.forward_context import set_forward_context
14
from vllm.logger import init_logger
15
from vllm.model_executor.model_loader import get_model
16
from vllm.model_executor.models import supports_multimodal
17
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
zhuwenwen's avatar
zhuwenwen committed
18

19
from vllm.utils import is_pin_memory_available
20
21
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
22
from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDecodeMetadata
zhuwenwen's avatar
zhuwenwen committed
23

24
from vllm.v1.kv_cache_interface import KVCacheConfig
25
26
from vllm.v1.sample.metadata import SamplingMetadata

27
28
# Module-level logger for the EAGLE proposer.
logger = init_logger(__name__)

# Sentinel slot index. ``EagleProposer.propose`` writes it into
# ``slot_mapping`` for draft tokens whose position exceeds the max model
# length, so that no real KV-cache slot is written for them.
PADDING_SLOT_ID = -1

class EagleProposer:
    """Draft-token proposer for EAGLE, EAGLE-3 and DeepSeek-MTP.

    The draft model is loaded by ``load_model`` and run autoregressively by
    ``propose`` for ``num_speculative_tokens`` steps. Input buffers
    (``input_ids``, ``positions``, ``hidden_states``) are allocated once in
    ``__init__`` so that the same storage can be reused by CUDA graphs.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        # Model runner; ``propose`` uses its ``attn_metadata_builders``.
        self.runner = runner
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # Reusable CPU arange used by ``prepare_inputs``.
        self.token_arange_np = np.arange(self.max_num_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE and
                               not self.vllm_config.model_config.enforce_eager)
        self.use_full_cuda_graph = (
            self.use_cuda_graph
            and vllm_config.compilation_config.full_cuda_graph)
        # ``propose`` treats ``cudagraph_batch_sizes[-1]`` as the largest
        # capturable size (assumes capture sizes are stored in descending
        # order in the compilation config — confirm if that ever changes).
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=device)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)

        # attention metadata captured in full cudagraph mode (set by
        # ``dummy_run``); ``propose`` copies fresh values into its tensors.
        self.attn_metadata_cudagraph = None

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
                                   1,
                                   device=device,
                                   dtype=torch.int32)

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Greedily propose ``num_speculative_tokens`` draft tokens.

        The first draft step runs the draft model over all target tokens
        (shifted by one and ending with ``next_token_ids``). Each later step
        feeds back one token per request. ``sampling_metadata`` is currently
        unused: drafts are always chosen with ``argmax``.

        Returns:
            Draft token ids of shape [batch_size, num_speculative_tokens].
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # Per-request length up to and including the last target token;
        # only consumed when building MLA decode metadata below.
        seq_lens = (target_positions[last_token_indices] + 1).int()

        assert self.runner is not None

        # FIXME: need to consider multiple kv_cache_groups
        attn_metadata = self.runner.attn_metadata_builders[0].build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {
            layer_name: attn_metadata
            for layer_name in self.attn_layer_names
        }

        if self.use_cuda_graph and \
                num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        if (self.use_full_cuda_graph
                and num_tokens <= self.cudagraph_batch_sizes[-1]):
            assert self.attn_metadata_cudagraph is not None
            if self.method == "deepseek_mtp":
                cg_metadata = self.attn_metadata_cudagraph
                cg_metadata.num_actual_tokens = (
                    attn_metadata.num_actual_tokens)
                cg_metadata.query_start_loc[:batch_size + 1] = (
                    attn_metadata.query_start_loc)
                cg_metadata.slot_mapping[:num_tokens] = (
                    attn_metadata.slot_mapping)
                cg_metadata.num_decodes = attn_metadata.num_decodes
                cg_metadata.num_decode_tokens = (
                    attn_metadata.num_decode_tokens)
                cg_metadata.num_prefills = attn_metadata.num_prefills
                if attn_metadata.decode is not None:
                    # block_table / seq_lens have one row per decode
                    # request, so slice by num_decodes.
                    num_decodes = attn_metadata.num_decodes
                    cg_metadata.decode.block_table[:num_decodes] = (
                        attn_metadata.decode.block_table)
                    cg_metadata.decode.seq_lens[:num_decodes] = (
                        attn_metadata.decode.seq_lens)

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )
            if self.method == "deepseek_mtp":
                last_hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = logits.argmax(dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.

        # The multi-step loop below mutates backend-specific metadata fields,
        # so only FlashAttention and MLA metadata are supported.
        is_mla = isinstance(attn_metadata, MLACommonMetadata)
        assert is_mla or isinstance(attn_metadata, FlashAttentionMetadata), (
            "Multi-token draft proposal only supports FlashAttention and "
            "MLA attention metadata.")

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        # Advanced indexing copies, so the in-place ``+= 1`` below does not
        # touch ``target_positions``.
        positions = target_positions[last_token_indices]
        if self.method == "deepseek_mtp":
            hidden_states = last_hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if self.use_cuda_graph and \
                batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size

        # Every remaining step processes exactly one token per request.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]

        if is_mla:
            attn_metadata.num_decodes = batch_size
            attn_metadata.num_decode_tokens = batch_size
            attn_metadata.num_prefills = 0
            builder = self.runner.attn_metadata_builders[0]
            attn_metadata.decode = builder._build_decode(
                block_table_tensor=(
                    common_attn_metadata.block_table_tensor[:batch_size]),
                seq_lens=seq_lens,
            )
            # MLA keeps its block table inside the decode metadata.
            block_table = attn_metadata.decode.block_table
        else:
            block_table = attn_metadata.block_table

        for i in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            # Increment the sequence lengths. For the requests that exceed
            # the max model length, we set the sequence length to 1 to
            # minimize their overheads in attention and keep block-table
            # lookups in range.
            if is_mla:
                attn_metadata.decode.seq_lens += 1
                attn_metadata.decode.seq_lens.masked_fill_(
                    exceeds_max_model_len, 1)
            else:
                attn_metadata.seq_lens += 1
                # Consider max model length.
                attn_metadata.max_seq_len = min(
                    attn_metadata.max_seq_len + 1, self.max_model_len)
                attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = block_table.gather(
                dim=1, index=block_numbers.view(-1, 1)).view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states

            if (self.use_full_cuda_graph
                    and batch_size <= self.cudagraph_batch_sizes[-1]):
                assert self.attn_metadata_cudagraph is not None
                if self.method == "deepseek_mtp":
                    cg_metadata = self.attn_metadata_cudagraph
                    num_decodes = attn_metadata.num_decodes
                    num_decode_tokens = attn_metadata.num_decode_tokens
                    cg_metadata.num_actual_tokens = (
                        attn_metadata.num_actual_tokens)
                    cg_metadata.slot_mapping[:num_decode_tokens] = (
                        attn_metadata.slot_mapping)
                    cg_metadata.num_decodes = num_decodes
                    cg_metadata.num_decode_tokens = num_decode_tokens
                    cg_metadata.num_prefills = attn_metadata.num_prefills
                    cg_metadata.decode.seq_lens[:num_decodes] = (
                        attn_metadata.decode.seq_lens)
                    if i == 0:
                        # query_start_loc and the block table do not change
                        # across draft steps; copy them only once.
                        cg_metadata.query_start_loc[:batch_size + 1] = (
                            attn_metadata.query_start_loc)
                        cg_metadata.decode.block_table[:num_decodes] = (
                            attn_metadata.decode.block_table)

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    self.input_ids[:input_batch_size],
                    self.positions[:input_batch_size],
                    self.hidden_states[:input_batch_size],
                )
                if self.method == "deepseek_mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = last_hidden_states[:batch_size]
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
                    hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        return torch.stack(draft_token_ids_list, dim=1)

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        # [batch_size]
        num_rejected_tokens: torch.Tensor
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for the spec decode.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #         [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #         [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #         [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                  q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                  q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
            - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        new_query_len_per_req = (query_start_loc_cpu[1:] -
                                 query_start_loc_cpu[:-1])
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available())
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                                  new_num_tokens_per_req_np)
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = self.token_arange_np[:total_num_tokens] \
            - new_query_start_locs_expanded

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
        # Final token indices are:
        # [0, 1,                                   // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(
            device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device,
                                                       non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # Upper bound: the per-request query length before rejection.
            max_query_len=new_query_len_per_req.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
        )

        return spec_common_attn_metadata, token_indices

    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and share weights with ``target_model``.

        Draft attention layers are identified as the attention layers that
        appear in the vLLM config only after the draft model is loaded.
        The vocab embedding is shared with the target when pipeline
        parallelism is off, the method is not ``deepseek_mtp``, and the
        embedding shapes match. The LM head is shared for every method except
        ``eagle3`` whenever the target language model has ``lm_head``.
        """
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        from vllm.compilation.backends import set_model_tag
        with set_model_tag("eagle_head"):
            self.model = get_model(vllm_config=self.vllm_config,
                                   model_config=draft_model_config)

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)
        self.attn_layer_names = list(draft_attn_layer_names)

        if supports_multimodal(target_model):
            # handle multimodality
            self.model.config.image_token_index = (
                target_model.config.image_token_index)
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        # (condition order matters: the shape check is only evaluated when
        # the earlier conditions hold).
        if get_pp_group().world_size == 1 \
                and self.method != "deepseek_mtp" \
                and self.model.model.embed_tokens.weight.shape \
                == target_language_model.model.embed_tokens.weight.shape:
            logger.info(
                "Assuming the EAGLE head shares the same vocab embedding"
                " with the target model.")
            del self.model.model.embed_tokens
            self.model.model.embed_tokens = (
                target_language_model.model.embed_tokens)
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                " from the target model.")

        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.method != "eagle3" and \
                hasattr(target_language_model, "lm_head"):
            logger.info("Loading EAGLE LM head weights from the target model.")
            self.model.lm_head = target_language_model.lm_head

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        attn_metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Run the draft model once on the first ``num_tokens`` buffer rows.

        The first time a non-None ``attn_metadata`` is passed, the entry for
        the first draft attention layer is kept as
        ``attn_metadata_cudagraph``; ``propose`` later copies fresh values
        into its tensors when full CUDA graphs are enabled.
        """
        if attn_metadata is not None and self.attn_metadata_cudagraph is None:
            self.attn_metadata_cudagraph = attn_metadata[
                self.attn_layer_names[0]]
        with set_forward_context(attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_tokens):
            self.model(
                self.input_ids[:num_tokens],
                self.positions[:num_tokens],
                self.hidden_states[:num_tokens],
            )

    def validate_same_kv_cache_group(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Raises:
            AssertionError: if the eagle layers span several groups.
            KeyError: if an eagle layer is not in any group.
        """
        layer_to_group: dict[str, int] = {
            layer_name: group_id
            for group_id, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        eagle_group_ids = {
            layer_to_group[layer_name]
            for layer_name in self.attn_layer_names
        }
        assert len(eagle_group_ids) == 1, \
            "All eagle layers should belong to the same kv cache group"


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pick one draft token per row of ``logits``.

    Returns ``(next_token_ids, probs)``.

    * If every request is greedy, ``probs`` is the unmodified ``logits``
      tensor and the ids are its row-wise argmax.
    * Otherwise ``logits`` is divided *in place* by the per-request
      temperature, where rows marked greedy (temperature == -1) use 1.0.
      ``probs`` is the float32 softmax of the result. Ids are sampled as
      ``argmax(probs / q)`` with ``q ~ Exp(1)``. Unless every request is
      random, greedy rows are then replaced with ``argmax(probs)``.
    """
    if sampling_metadata.all_greedy:
        # Draft probs are not consulted for greedy requests during rejection
        # sampling, so skip the softmax and hand back the logits.
        return logits.argmax(dim=-1), logits

    greedy_rows = sampling_metadata.temperature == -1
    effective_temperature = torch.where(greedy_rows, 1.0,
                                        sampling_metadata.temperature)
    logits.div_(effective_temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    exp_noise = torch.empty_like(probs)
    exp_noise.exponential_()
    # Divide out of place: ``probs`` is returned and reused for rejection
    # sampling, so it must not be overwritten.
    sampled_ids = probs.div(exp_noise).argmax(dim=-1).view(-1)

    if sampling_metadata.all_random:
        return sampled_ids, probs

    greedy_ids = probs.argmax(dim=-1)
    return torch.where(greedy_rows, greedy_ids, sampled_ids), probs