gpu_model_runner.py 39.2 KB
Newer Older
1
from typing import Any, Optional, Union
lizhigong's avatar
lizhigong committed
2
3
import torch
import numpy as np
4
from vllm import envs
5
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
6
7
8
9
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.utils import async_tensor_h2d, round_up
lizhigong's avatar
lizhigong committed
10
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
11
12
13
14
15
16
17
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
lizhigong's avatar
lizhigong committed
18
from vllm.v1.worker.block_table import BlockTable
19
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
20
from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
lizhigong's avatar
lizhigong committed
21
22
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile
23
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
王敏's avatar
王敏 committed
24
from vllm.v1.spec_decode.utils import DraftProbs
lizhigong's avatar
lizhigong committed
25

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import triton
import triton.language as tl

# Scatter the last valid token of selected rows of a 2-D id buffer into a
# 1-D input-id buffer. Launched with one program per update entry: program
# `pid` reads row `update_req_ptr[pid]` of `last_ids` and writes a single
# value to `input_ids[input_pos_ptr[pid]]`.
#
# "Valid" means != -1: the kernel takes the largest column index whose value
# is not -1 and stores that value. If every column of the row is -1, the
# masked load below falls back to `other=0`, so 0 is written.
# NOTE(review): -1 is presumably the rejection sampler's placeholder for
# rejected / unused slots, and an all -1 row is presumably impossible —
# confirm, since in that case a real input id would be overwritten with 0.
#
# The whole row must fit in one tl.arange(0, BLOCK_T), so BLOCK_T >= T is
# required (the launcher in this file asserts it). Lanes with offs >= T are
# masked and read as -1, so they never win the max-reduction below.
@triton.jit
def fused_last_valid_scatter_kernel(
    last_ids_ptr,        # [B, T] token ids; -1 marks an unused slot
    input_ids_ptr,       # [N] destination ids, written in place
    update_req_ptr,      # [U] row of last_ids to read for each program
    input_pos_ptr,       # [U] position in input_ids to write for each program
    stride0,             # element stride between rows of last_ids
    stride1,             # element stride between columns of last_ids
    T,                   # number of columns of last_ids
    BLOCK_T: tl.constexpr,
):
    pid = tl.program_id(0)

    # indices: source row and destination position for this program
    req_idx = tl.load(update_req_ptr + pid)
    input_pos = tl.load(input_pos_ptr + pid)

    # load row (out-of-range lanes read as -1, i.e. "invalid")
    offs = tl.arange(0, BLOCK_T)
    mask = offs < T

    row_ptr = last_ids_ptr + req_idx * stride0 + offs * stride1
    vals = tl.load(row_ptr, mask=mask, other=-1)

    # Index reduction: replace each valid lane with its column index and
    # each invalid lane with -1, then take the max to get the last valid
    # column (or -1 if there is none).
    idx = tl.where(vals != -1, offs, -1)
    last_idx = tl.max(idx, axis=0)

    # load last token (masked when no valid column exists -> 0)
    last_val = tl.load(
        last_ids_ptr + req_idx * stride0 + last_idx * stride1,
        mask=last_idx >= 0,
        other=0,
    )

    # scatter into the destination position
    tl.store(input_ids_ptr + input_pos, last_val)
lizhigong's avatar
lizhigong committed
66

67
68
69
class V1ZeroModelRunner(GPUModelRunner):
    def __init__(self, vllm_config, device):
        """Create the runner and the cross-step state used by zero-overhead mode.

        All attributes below carry results of one step into the next one
        (sampled tokens, draft tokens, and the CUDA events used to wait for
        them). If the base runner created an EAGLE drafter, it is replaced
        by the zero-overhead ``V1ZeroEagleProposer``.
        """
        super().__init__(vllm_config, device)

        def _plain_event():
            # Non-timing CUDA event; used for host-side waits such as
            # last_sampler_event.synchronize() before reading host tokens.
            return torch.cuda.Event(enable_timing=False)

        # Target-model sampling results from the previous step.
        self.last_sampled_token_ids = None
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        self.last_sampler_event = _plain_event()
        self.last_sampler_host_tokens = None

        # (req_idx, start_idx, end_idx) entries used to patch the CPU token
        # bookkeeping once the real sampled tokens are known.
        self.token_ids_cpu_fix_record = []

        # Draft-model (speculative decoding) results from the previous step.
        self.last_draft_token_ids = None
        self.last_draft_host_tokens = None
        self.last_draft_event = _plain_event()

        self.spec_sampler_event = _plain_event()
        self.spec_scheduler_max_num_tokens = 0

        # Per-step fix-up state (request ids and their real sampled tokens).
        self.fix_req_ids = None
        self.fix_sampled_token_ids = None

        # Swap a stock EAGLE drafter for the zero-overhead variant.
        existing_drafter = getattr(self, 'drafter', None)
        if isinstance(existing_drafter, EagleProposer):
            self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
                                               self)
lizhigong's avatar
lizhigong committed
87
    
lizhigong's avatar
lizhigong committed
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[dict[str, Any], bool, torch.Tensor,
               Optional[SpecDecodeMetadata], np.ndarray]:
        """Build the device inputs and attention metadata for one step.

        On top of the upstream preparation, this zero-overhead variant first
        patches the CPU token bookkeeping with the tokens really sampled in
        the previous step (``_apply_last_step_token_fixes``), and after the
        host-to-device copy of ``input_ids`` lets ``zero_prepare_inputs``
        overwrite slots of ``self.input_ids`` on the device.

        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
            attention_cuda_graphs: whether attention can run in cudagraph,
            logits_indices: hidden-state rows to compute logits for,
            spec_decode_metadata: None when no request has draft tokens,
            num_scheduled_tokens: per-request scheduled token counts (int32)
        ]
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)
        # Consumed by propose_draft_token_ids for the drafter.
        self.spec_scheduler_max_num_tokens = max_num_scheduled_tokens

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)

        # Must run before positions / token ids are derived below: it
        # rewrites token_ids_cpu and num_computed_tokens_cpu.
        self._apply_last_step_token_fixes()

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping for each KV cache group.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            block_size = kv_cache_group_spec.kv_cache_spec.block_size
            block_table: BlockTable = self.input_batch.block_table[
                kv_cache_group_id]
            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
            # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
            # where K is the max_num_blocks_per_req and the block size is 2.
            # NOTE(woosuk): We can't simply use `token_indices // block_size`
            # here because M (max_model_len) is not necessarily divisible by
            # block_size.
            block_table_indices = (
                req_indices * block_table.max_num_blocks_per_req +
                positions_np // block_size)
            block_table_cpu = block_table.get_cpu_tensor()
            block_numbers = block_table_cpu.flatten(
            )[block_table_indices].numpy()
            block_offsets = positions_np % block_size
            np.add(
                block_numbers * block_size,
                block_offsets,
                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)

        # Overwrite slots with the previous step's device-side tokens; queued
        # after the copy above so the device values win.
        self.zero_prepare_inputs(scheduler_output, self.input_ids)

        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        self.query_start_loc[:num_reqs + 1].copy_(
            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                       non_blocking=True)

        # Zero the unused tail of seq_lens. Needed for reshape_and_cache
        self.seq_lens[num_reqs:].fill_(0)
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention requires that
        self.query_start_loc[num_reqs + 1:].fill_(
            self.query_start_loc_cpu[num_reqs].item())

        query_start_loc = self.query_start_loc[:num_reqs + 1]
        seq_lens = self.seq_lens[:num_reqs]

        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
        )

        attn_metadata: dict[str, Any] = {}
        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):

            # Prepare for cascade attention if enabled & beneficial.
            common_prefix_len = 0
            builder = self.attn_metadata_builders[kv_cache_group_id]
            if self.cascade_attn_enabled:
                common_prefix_len = self._compute_cascade_attn_prefix_len(
                    num_scheduled_tokens,
                    scheduler_output.
                    num_common_prefix_blocks[kv_cache_group_id],
                    kv_cache_group_spec.kv_cache_spec,
                    builder,
                )

            attn_metadata_i = (builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata,
            ))

            for layer_name in kv_cache_group_spec.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

        attention_cuda_graphs = all(
            b.can_run_in_cudagraph(common_attn_metadata)
            for b in self.attn_metadata_builders)

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_ids = None
            if envs.VLLM_REJECT_SAMPLE_OPT:
                spec_decode_ids = scheduler_output.scheduled_spec_decode_tokens.keys()

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens, spec_decode_ids)
            logits_indices = spec_decode_metadata.logits_indices

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return (attn_metadata, attention_cuda_graphs, logits_indices,
                spec_decode_metadata, num_scheduled_tokens)

    def _apply_last_step_token_fixes(self) -> None:
        """Patch CPU token bookkeeping with the previous step's real samples.

        Always sets ``fix_req_ids`` to ``last_sampled_req_ids``. When
        ``last_sampler_host_tokens`` is not None it waits on
        ``last_sampler_event`` and, in speculative mode, converts the host
        tokens into per-request lists (``fix_sampled_token_ids``). It does
        this either directly (one token per request) or via
        ``rejection_sampler.parse_output``.

        Then it handles each ``(req_idx, start_idx, end_idx)`` in
        ``token_ids_cpu_fix_record``. ``req_idx`` indexes the previous
        step's request list ``fix_req_ids``.

        * ``start_idx == -1``: that request's sampled list is cleared.
        * otherwise, if the request is still in the current batch:
          ``token_ids_cpu[start_idx:new_end]`` is overwritten with the real
          tokens. ``num_tokens`` / ``num_tokens_no_spec`` are set to
          ``new_end``. ``num_computed_tokens_cpu`` is reduced by
          ``end_idx - new_end``, the reserved slots that were not filled.
        * if the request is still in ``self.requests``, the tokens are
          appended to its ``output_token_ids``.

        NOTE(review): outside speculative mode ``fix_sampled_token_ids`` is
        not assigned here; presumably it is set elsewhere — confirm.
        """
        self.fix_req_ids = self.last_sampled_req_ids
        if self.last_sampler_host_tokens is None:
            return

        # Wait for the previous step's target model / sampler to finish.
        self.last_sampler_event.synchronize()
        if self.speculative_config:  # Handle the previous step's MTP output.
            host_tokens = self.last_sampler_host_tokens
            num_gen_tokens = host_tokens.shape[-1]
            if num_gen_tokens == 1:
                self.fix_sampled_token_ids = host_tokens.tolist()
            else:
                # Includes spec decode tokens.
                self.fix_sampled_token_ids = self.rejection_sampler.parse_output(
                    host_tokens,
                    self.input_batch.vocab_size,
                )

        input_batch = self.input_batch
        for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
            sampled = self.fix_sampled_token_ids[req_idx]
            if start_idx == -1:
                sampled.clear()
                continue

            req_id = self.fix_req_ids[req_idx]
            # O(1) lookup instead of `in req_ids` + `req_ids.index(...)`.
            new_req_idx = input_batch.req_id_to_index.get(req_id)
            if new_req_idx is not None:
                new_end_idx = start_idx + len(sampled)
                # Update the token count statistics.
                input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx
                input_batch.num_tokens[new_req_idx] = new_end_idx
                input_batch.token_ids_cpu[new_req_idx,
                                          start_idx:new_end_idx] = sampled
                input_batch.num_computed_tokens_cpu[new_req_idx] -= (
                    end_idx - new_end_idx)

            req_state = self.requests.get(req_id)
            if req_state is not None:
                req_state.output_token_ids.extend(sampled)

322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
    def zero_prepare_inputs(self, scheduler_output, input_ids):
        """Overwrite device ``input_ids`` slots with the previous step's tokens.

        The tokens drafted / sampled in the previous step live in device
        tensors (``last_draft_token_ids`` / ``last_sampled_token_ids``). This
        method scatters them into ``input_ids`` on the device, overwriting
        whatever the host copy put there:

        * draft token k of a request goes to offset ``k + 1`` of that request;
        * the last valid sampled token goes to offset 0 of that request.

        Only requests present in ``last_sampled_req_ids`` are touched.
        Request offsets are the running sum of
        ``scheduler_output.num_scheduled_tokens`` in ``input_batch.req_ids``
        order, which is the token layout ``_prepare_inputs`` builds.

        Args:
            scheduler_output: the current step's scheduler output.
            input_ids: device tensor of input token ids; modified in place.
        """
        if (self.last_draft_token_ids is None
                and self.last_sampled_token_ids is None):
            return

        req_ids = self.input_batch.req_ids
        # Row of each previous-step request in the last_* tensors, built once
        # (first occurrence wins, as with list.index).
        prev_row: dict[str, int] = {}
        for row, req_id in enumerate(self.last_sampled_req_ids):
            prev_row.setdefault(req_id, row)

        self._zero_scatter_draft_tokens(scheduler_output, input_ids, req_ids,
                                        prev_row)
        self._zero_scatter_sampled_tokens(scheduler_output, input_ids,
                                          req_ids, prev_row)

    def _zero_scatter_draft_tokens(self, scheduler_output, input_ids, req_ids,
                                   prev_row):
        """Write ``last_draft_token_ids`` into each request's draft slots."""
        if self.last_draft_token_ids is None:
            return
        draft_tokens_num = self.last_draft_token_ids.shape[1]
        spec_tokens = scheduler_output.scheduled_spec_decode_tokens

        src_indices = []
        dst_indices = []
        token_idx = 0
        for req_id in req_ids:
            row = prev_row.get(req_id)
            if row is not None:
                # Fill only the draft slots the scheduler actually reserved
                # this step; writing more would spill into the next request.
                num_slots = min(draft_tokens_num,
                                len(spec_tokens.get(req_id, ())))
                # flatten() is row-major: element (row, k) is at
                # row * draft_tokens_num + k.
                base = row * draft_tokens_num
                for k in range(num_slots):
                    src_indices.append(base + k)
                    dst_indices.append(token_idx + k + 1)
            # Advance by the real per-request token count, exactly like the
            # layout built in _prepare_inputs.
            token_idx += scheduler_output.num_scheduled_tokens[req_id]

        if not src_indices:
            return
        src_tensor = async_tensor_h2d(src_indices, torch.int32, self.device,
                                      True)
        dst_tensor = async_tensor_h2d(dst_indices, torch.int32, self.device,
                                      True)
        flat_drafts = self.last_draft_token_ids.flatten().to(torch.int)
        input_ids[dst_tensor] = flat_drafts[src_tensor]

    def _zero_scatter_sampled_tokens(self, scheduler_output, input_ids,
                                     req_ids, prev_row):
        """Write the last valid sampled token to offset 0 of each request."""
        if self.last_sampled_token_ids is None:
            return

        src_rows = []
        dst_positions = []
        token_idx = 0
        for req_id in req_ids:
            row = prev_row.get(req_id)
            if row is not None:
                src_rows.append(row)
                dst_positions.append(token_idx)
            token_idx += scheduler_output.num_scheduled_tokens[req_id]

        if not src_rows:
            return
        src_tensor = async_tensor_h2d(src_rows, torch.int32, self.device,
                                      True)
        dst_tensor = async_tensor_h2d(dst_positions, torch.int32, self.device,
                                      True)

        sampled = self.last_sampled_token_ids
        _, num_cols = sampled.shape  # must be 2-D [rows, tokens]
        block_t = 1024
        # The kernel loads a whole row with one tl.arange(0, BLOCK_T).
        assert num_cols <= block_t
        fused_last_valid_scatter_kernel[(len(src_rows), )](
            sampled,
            input_ids,
            src_tensor,
            dst_tensor,
            sampled.stride(0),
            sampled.stride(1),
            num_cols,
            BLOCK_T=block_t,
        )

jujl1's avatar
jujl1 committed
393

394
395
396
397

    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        num_accepted_tokens_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        attn_metadata: dict[str, Any],
    ) -> list[list[int]]:
        """Run the configured drafter and return per-request draft token ids.

        For the ngram and medusa methods the drafter's output is returned
        directly. For EAGLE-style drafters the real draft tokens are kept
        on the device in ``last_draft_token_ids``, and a non-blocking host
        copy is queued, tracked by ``last_draft_event``. The returned lists
        are placeholders of ones with the same shape; the real tokens are
        scattered into ``input_ids`` on the device in the next step.

        Raises:
            ValueError: if the speculative method is not ngram, medusa or
                EAGLE-style.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)
        elif self.speculative_config.method == "medusa":
            assert isinstance(self.drafter, MedusaProposer)
            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                hidden_states = sample_hidden_states
            else:
                indices = []
                offset = 0
                for num_draft, tokens in zip(
                        spec_decode_metadata.num_draft_tokens,
                        sampled_token_ids):
                    indices.append(offset + len(tokens) - 1)
                    offset += num_draft + 1
                indices = torch.tensor(indices, device=self.device)
                hidden_states = sample_hidden_states[indices]

            spec_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            # For each row, take the token at column num_accepted_tokens.
            # NOTE(review): presumably that is the bonus/recovered token
            # following the accepted drafts — confirm.
            row_indices = torch.arange(sampled_token_ids.size(0),
                                       device=sampled_token_ids.device)
            next_token_ids = sampled_token_ids[
                row_indices, num_accepted_tokens_tensor].flatten()

            # At this moment, we assume all eagle layers belong to the same KV
            # cache group, thus using the same attention metadata.
            eagle_attn_metadata = attn_metadata[
                self.drafter.attn_layer_names[0]]

            # NOTE: deepseek_mtp uses MLA which does not have `block_table`
            if hasattr(eagle_attn_metadata, "block_table"):
                block_table = eagle_attn_metadata.block_table
            else:
                block_table = None

            spec_scheduler_max_num_tokens = self.spec_scheduler_max_num_tokens
            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[:num_scheduled_tokens]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
                target_slot_mapping = eagle_attn_metadata.slot_mapping
                cu_num_tokens = eagle_attn_metadata.query_start_loc
            else:
                # TODO(woosuk): Refactor this.
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                    eagle_attn_metadata.query_start_loc,
                    num_accepted_tokens_tensor,
                )
                spec_scheduler_max_num_tokens = 1
                target_token_ids = self.input_ids[token_indices]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[token_indices]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    target_hidden_states = hidden_states[token_indices]
                target_slot_mapping = eagle_attn_metadata.slot_mapping[
                    token_indices]
            self.drafter.spec_scheduler_max_num_tokens = spec_scheduler_max_num_tokens
            draft_result = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
                block_table=block_table,
                sampling_metadata=sampling_metadata,
                decoding=spec_decode_metadata is not None,
            )

            if not envs.VLLM_REJECT_SAMPLE_OPT:
                draft_token_ids = draft_result
            else:
                draft_token_ids, draft_probs = draft_result
                # Rows of draft_probs follow the input-batch order (the same
                # order as sampled_token_ids / next_token_ids), not the
                # insertion order of scheduler_output.num_scheduled_tokens.
                draft_req_ids = list(self.input_batch.req_ids)

                current_probs = getattr(self, "draft_probs", None)
                if current_probs is None:
                    self.draft_probs = DraftProbs(
                        draft_probs, draft_req_ids)
                else:
                    current_probs.update(draft_probs, draft_req_ids)

            # Placeholder ids for the scheduler: only the shape matters here,
            # the real drafts stay on the device in last_draft_token_ids.
            spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
            self.last_draft_token_ids = draft_token_ids
            self.last_draft_host_tokens = draft_token_ids.to('cpu', non_blocking=True)
            self.last_draft_event.record()
        else:
            raise ValueError(
                "Unsupported speculative decoding method: "
                f"{self.speculative_config.method}")
        return spec_token_ids

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            return self.kv_connector_no_forward(scheduler_output)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
         spec_decode_metadata,
         num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
529
530
531
532
533
534
535
536
537
538
        
        # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
        if self.ep_sp:
            num_input_tokens = round_up(num_scheduled_tokens, tp_size)
            if (self.use_cuda_graph
                    and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
                # Use piecewise CUDA graphs.
                # Add padding to the batch size.
                num_input_tokens = self.vllm_config.pad_for_cudagraph(
                    num_input_tokens)
539
        else:
540
541
542
543
544
545
            if (self.use_cuda_graph
                    and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
                # Use piecewise CUDA graphs.
                # Add padding to the batch size.
                num_input_tokens = self.vllm_config.pad_for_cudagraph(
                    num_scheduled_tokens)
546
            else:
547
548
549
550
551
552
553
554
555
                # Eager mode.
                # Pad tokens to multiple of tensor_parallel_size when
                # enabled collective fusion for SP
                tp_size = self.vllm_config.parallel_config.tensor_parallel_size
                if self.compilation_config.pass_config. \
                        enable_sequence_parallelism and tp_size > 1:
                    num_input_tokens = round_up(num_scheduled_tokens, tp_size)
                else:
                    num_input_tokens = num_scheduled_tokens
王敏's avatar
王敏 committed
556

557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
        num_input_tokens += num_pad

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.is_multimodal_model and get_pp_group().is_first_rank:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        # Some attention backends only support CUDA Graphs in pure decode.
        # If attention doesn't support CUDA Graphs for this batch, but we
        # compiled with full CUDA graphs, we have to skip them entirely.
        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
606
        if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph or skip_cuda_graphs):
607
608
609
            model_output, finished_sending, finished_recving = \
                 tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
                                             num_tokens_across_dp, input_ids, positions,
610
611
                                             inputs_embeds, scheduler_output, intermediate_tensors, 
                                             skip_cuda_graphs)
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
        else:
            # Run the model.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_input_tokens,
                    num_tokens_across_dp=num_tokens_across_dp,
                    skip_cuda_graphs=skip_cuda_graphs,
            ):
                self.maybe_setup_kv_connector(scheduler_output)

                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )

                self.maybe_wait_for_kv_save()
                finished_sending, finished_recving = (
                    self.get_finished_kv_transfers(scheduler_output))
        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping mirco-batches
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = \
            self.parallel_config.distributed_executor_backend \
            == "external_launcher" and len(get_pp_group().ranks) > 0
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(hidden_states.tensors,
                                            all_gather_group=get_tp_group())
            logits = None
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np, finished_sending,
                                  finished_recving)

            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states, None)
        if broadcast_pp_output:
            model_output_broadcast_data = {
                "logits": logits.contiguous(),
            } if logits is not None else {}
            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
            assert model_output_broadcast_data is not None
            logits = model_output_broadcast_data["logits"]

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            assert logits is not None
王敏's avatar
王敏 committed
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
            if not envs.VLLM_REJECT_SAMPLE_OPT:
                bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
                sampler_output = self.sampler(
                    logits=bonus_logits,
                    sampling_metadata=sampling_metadata,
                )
                bonus_token_ids = sampler_output.sampled_token_ids

                # Just like `bonus_logits`, `target_logits` is a new tensor with
                # separate storage from the original `logits` tensor. Therefore,
                # it is safe to update `target_logits` in place.
                target_logits = logits[spec_decode_metadata.target_logits_indices]
                output_token_ids = self.rejection_sampler(
                    spec_decode_metadata,
                    None,  # draft_probs
                    target_logits,
                    bonus_token_ids,
                    sampling_metadata,
                )
                sampler_output.sampled_token_ids = output_token_ids
            else:
                # sampling_metadata.all_greedy = True
                # sampling_metadata.all_random = False
                sampler_output = self.sampler(
                    logits=logits,
                    sampling_metadata=sampling_metadata,
                )
                target_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.target_logits_indices]
                target_logits = logits[spec_decode_metadata.target_logits_indices]

                bonus_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.bonus_logits_indices]

                output_token_ids = self.rejection_sampler(
                    spec_decode_metadata,
                    self.draft_probs.get_probs(spec_decode_metadata.spec_decode_ids),
                    target_logits,
                    target_token_ids,
                    bonus_token_ids,
                    sampling_metadata,
                )
                sampler_output.sampled_token_ids = output_token_ids
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761

        num_nans_in_logits = {}
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
lizhigong's avatar
lizhigong committed
762
763
764
            scheduler_output,
        )

765
766
767
        fix_draft_token_ids = None
        fix_draft_req_ids = self.last_sampled_req_ids
        is_output_valid = False
768
769
770
        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids

jujl1's avatar
jujl1 committed
771
772
773
774
775
776
        self.last_sampler_host_tokens = None
        self.last_sampled_token_ids = None
        self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
        self.last_sampler_event.record()
        self.last_sampled_token_ids = sampled_token_ids
        valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
777

778
779
780
781
782
783
784
785
786
        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
787
        self.token_ids_cpu_fix_record.clear()
788
789
790
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
791
792
793
            req_id = self.input_batch.req_ids[req_idx]
            self.last_sampled_req_ids.append(req_id)
            cache_output_len = -1
794
            if not sampled_ids:
795
796
                self.last_sampled_token_lens.append(-1)
                self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
797
798
799
800
801
802
803
804
805
806
807
                continue

            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(sampled_ids)
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}")

            self.input_batch.token_ids_cpu[req_idx,
                                            start_idx:end_idx] = sampled_ids
808
            self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
809
810
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx
jujl1's avatar
jujl1 committed
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836

        if not self.speculative_config:
            # Speculative decoding is not enabled.
            spec_token_ids = None
            fix_draft_req_ids = None
        else:
            sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
            if self.last_draft_host_tokens is not None:
                self.last_draft_event.synchronize()
                fix_draft_token_ids = self.last_draft_host_tokens.tolist()

            mask = (sampled_token_ids == -1)
            mask_int = mask.int()
            first_neg_one_indices = torch.argmax(mask_int, dim=1)
            num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
            spec_token_ids = self.propose_draft_token_ids(
                scheduler_output,
                num_accepted_tokens_tensor,
                sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                attn_metadata,
            )
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854

        # Clear KVConnector state after all KVs are generated.
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()

        self.eplb_step()

        model_output = ZeroV1ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            num_nans_in_logits=num_nans_in_logits,
jujl1's avatar
jujl1 committed
855
856
857
858
            fix_req_ids=self.fix_req_ids,
            fix_sampled_token_ids=self.fix_sampled_token_ids,
            fix_draft_tokens_ids=fix_draft_token_ids,
            fix_draft_req_ids=fix_draft_req_ids,
859
860
861
            is_output_valid=is_output_valid
        )
        return model_output