gpu_model_runner.py 42.4 KB
Newer Older
1
from typing import Any, Optional, Union
lizhigong's avatar
lizhigong committed
2
3
import torch
import numpy as np
4
from vllm import envs
5
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
6
7
8
9
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.utils import async_tensor_h2d, round_up
lizhigong's avatar
lizhigong committed
10
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
11
12
13
14
15
16
17
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
lizhigong's avatar
lizhigong committed
18
from vllm.v1.worker.block_table import BlockTable
19
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
20
from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
lizhigong's avatar
lizhigong committed
21
22
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile
23
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
王敏's avatar
王敏 committed
24
from vllm.v1.spec_decode.utils import DraftProbs
25
from vllm.zero_overhead.utils import fused_update_input_ids_impl
lizhigong's avatar
lizhigong committed
26

27
28
29
class V1ZeroModelRunner(GPUModelRunner):
    def __init__(self, vllm_config, device):
        """Set up the base runner, then add the zero-overhead bookkeeping.

        Args:
            vllm_config: Configuration object forwarded to ``GPUModelRunner``.
            device: Device forwarded to ``GPUModelRunner``.
        """
        super().__init__(vllm_config, device)

        def _make_event():
            # All events here are used for ordering only, never for timing.
            return torch.cuda.Event(enable_timing=False)

        # Results of the most recent sampling step.
        self.last_sampled_token_ids = None
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        self.last_sampler_host_tokens = None
        self.last_sampler_event = _make_event()

        # (req_idx, start_idx, end_idx) records consumed by _prepare_inputs
        # to patch the CPU-side token buffers.
        self.token_ids_cpu_fix_record = []

        # Results of the most recent draft proposal.
        self.last_draft_token_ids = None
        self.last_draft_host_tokens = None
        self.last_draft_event = _make_event()

        self.spec_sampler_event = _make_event()
        self.spec_scheduler_max_num_tokens = 0

        # Snapshots taken by _prepare_inputs while applying fix records.
        self.fix_req_ids = None
        self.fix_sampled_token_ids = None

        # Swap a stock Eagle drafter for the zero-overhead variant. This is
        # done last because the new proposer is handed this runner instance.
        if isinstance(getattr(self, 'drafter', None), EagleProposer):
            self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
                                               self)
lizhigong's avatar
lizhigong committed
47
    
lizhigong's avatar
lizhigong committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[dict[str, Any], bool, torch.Tensor,
               Optional[SpecDecodeMetadata], np.ndarray]:
        """Build the model inputs and attention metadata for one step.

        Besides filling the persistent input buffers (input ids, positions,
        slot mappings, query_start_loc, seq_lens), this records the largest
        per-request scheduled token count in
        ``self.spec_scheduler_max_num_tokens`` and, when the enhanced
        zero-overhead mode is active with speculative decoding, first patches
        the CPU token buffers with the previous step's sampled tokens.

        Args:
            scheduler_output: The scheduler's decision for this step.

        Returns:
            A 5-tuple ``(attn_metadata, attention_cuda_graphs,
            logits_indices, spec_decode_metadata, num_scheduled_tokens)``:

            * attn_metadata: layer-name to attention-metadata mapping.
            * attention_cuda_graphs: whether every attention metadata
              builder reports it can run in a CUDA graph.
            * logits_indices: indices of the hidden states to sample from.
            * spec_decode_metadata: speculative decoding metadata, or
              ``None`` when no draft tokens were scheduled.
            * num_scheduled_tokens: int32 array with the number of scheduled
              tokens per request, in batch order.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)
        self.spec_scheduler_max_num_tokens = max_num_scheduled_tokens

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)

        # This must run before positions are derived below, because it
        # adjusts num_computed_tokens_cpu.
        if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and self.speculative_config
                and self.last_sampler_host_tokens is not None):
            self._zero_apply_token_fix_records()

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping for each KV cache group.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            block_size = kv_cache_group_spec.kv_cache_spec.block_size
            block_table: BlockTable = self.input_batch.block_table[
                kv_cache_group_id]
            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
            # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
            # where K is the max_num_blocks_per_req and the block size is 2.
            # NOTE(woosuk): We can't simply use `token_indices // block_size`
            # here because M (max_model_len) is not necessarily divisible by
            # block_size.
            block_table_indices = (
                req_indices * block_table.max_num_blocks_per_req +
                positions_np // block_size)
            block_table_cpu = block_table.get_cpu_tensor()
            block_numbers = block_table_cpu.flatten(
            )[block_table_indices].numpy()
            block_offsets = positions_np % block_size
            np.add(
                block_numbers * block_size,
                block_offsets,
                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)

        # Overwrite entries of the device-side input ids with the tokens
        # kept from the previous step (drafts and sampled tokens).
        self.zero_prepare_inputs(scheduler_output, self.input_ids)

        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        self.query_start_loc[:num_reqs + 1].copy_(
            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                       non_blocking=True)

        # Fill unused with 0. Needed for reshape_and_cache
        self.seq_lens[num_reqs:].fill_(0)
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention requires that
        self.query_start_loc[num_reqs + 1:].fill_(
            self.query_start_loc_cpu[num_reqs].item())

        query_start_loc = self.query_start_loc[:num_reqs + 1]
        seq_lens = self.seq_lens[:num_reqs]

        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
        )

        attn_metadata: dict[str, Any] = {}
        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):

            # Prepare for cascade attention if enabled & beneficial.
            common_prefix_len = 0
            builder = self.attn_metadata_builders[kv_cache_group_id]
            if self.cascade_attn_enabled:
                common_prefix_len = self._compute_cascade_attn_prefix_len(
                    num_scheduled_tokens,
                    scheduler_output.
                    num_common_prefix_blocks[kv_cache_group_id],
                    kv_cache_group_spec.kv_cache_spec,
                    builder,
                )

            attn_metadata_i = (builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata,
            ))

            for layer_name in kv_cache_group_spec.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

        attention_cuda_graphs = all(
            b.can_run_in_cudagraph(common_attn_metadata)
            for b in self.attn_metadata_builders)

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            # The request ids are only passed along when the optimized
            # rejection-sampling path is enabled.
            spec_decode_ids = None
            if envs.VLLM_REJECT_SAMPLE_OPT:
                spec_decode_ids = scheduler_output.scheduled_spec_decode_tokens.keys()

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens, spec_decode_ids)
            logits_indices = spec_decode_metadata.logits_indices

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return (attn_metadata, attention_cuda_graphs, logits_indices,
                spec_decode_metadata, num_scheduled_tokens)

    def _zero_apply_token_fix_records(self) -> None:
        """Patch CPU-side token state with the previous step's sampled tokens.

        Waits on ``last_sampler_event``, converts ``last_sampler_host_tokens``
        into per-request token lists, then replays every
        ``(req_idx, start_idx, end_idx)`` entry of
        ``token_ids_cpu_fix_record``: it updates the input batch's token
        counts, token ids and computed-token count for requests that are
        still in the batch, and appends the tokens to the request state's
        ``output_token_ids``.
        """
        self.fix_req_ids = self.last_sampled_req_ids
        # Wait for the previous step's main-model sampling to finish before
        # reading the host copy of its tokens.
        self.last_sampler_event.synchronize()
        num_gen_tokens = self.last_sampler_host_tokens.shape[-1]
        if num_gen_tokens == 1:
            self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
        else:
            self.fix_sampled_token_ids = self.rejection_sampler.parse_output(
                self.last_sampler_host_tokens,
                self.input_batch.vocab_size,
            )

        input_batch = self.input_batch
        req_id_to_index = input_batch.req_id_to_index
        for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
            if start_idx == -1:
                # A start index of -1 marks a record whose tokens are dropped.
                self.fix_sampled_token_ids[req_idx].clear()
                continue

            accepted_tokens = self.fix_sampled_token_ids[req_idx]
            req_id = self.fix_req_ids[req_idx]
            # The request may have moved to a different row (or left the
            # batch) since the record was written, so look it up by id.
            if req_id in req_id_to_index:
                new_req_idx = req_id_to_index[req_id]
                new_end_idx = start_idx + len(accepted_tokens)
                # Update the token counters.
                input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx
                input_batch.num_tokens[new_req_idx] = new_end_idx
                input_batch.token_ids_cpu[
                    new_req_idx, start_idx:new_end_idx] = accepted_tokens
                # Shrink the computed-token count by the gap between the
                # recorded end index and the end index actually reached.
                input_batch.num_computed_tokens_cpu[new_req_idx] -= (
                    end_idx - new_end_idx)
            if req_id in self.requests:
                req_state = self.requests[req_id]
                req_state.output_token_ids.extend(accepted_tokens)

281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
    def zero_prepare_inputs(self, scheduler_output, input_ids):
        """Patch ``input_ids`` in place with tokens kept from the last step.

        Two passes are made over the current batch:

        1. If ``last_draft_token_ids`` is set, the draft tokens of every
           request that was also in the previous step are written right
           after that request's first scheduled token.
        2. If ``last_sampled_token_ids`` is set, the previously sampled
           token(s) are written at the request's first scheduled position.

        Args:
            scheduler_output: Scheduler output for this step; its
                ``num_scheduled_tokens`` mapping gives each request's span
                inside the flattened ``input_ids``.
            input_ids: Device tensor of flattened input token ids; modified
                in place.
        """
        req_ids = self.input_batch.req_ids
        num_scheduled = scheduler_output.num_scheduled_tokens

        # Row of each request in the previous step's outputs. setdefault
        # keeps the first occurrence, which is what list.index would return.
        last_row_of = {}
        for row, last_req_id in enumerate(self.last_sampled_req_ids):
            last_row_of.setdefault(last_req_id, row)

        # Pass 1: draft tokens.
        if self.last_draft_token_ids is not None:
            update_req_indices = []
            input_ids_indices = []
            token_idx = 0
            draft_tokens_num = self.last_draft_token_ids.shape[1]
            for req_id in req_ids:
                scheduled = num_scheduled[req_id]
                row = last_row_of.get(req_id)
                if row is not None:
                    req_idx = row * draft_tokens_num
                    # Slot 0 of the span is the sampled token; the drafts
                    # follow it. Never write past this request's own span.
                    writable = min(draft_tokens_num, scheduled - 1)
                    for num_idx in range(writable):
                        update_req_indices.append(req_idx + num_idx)
                        input_ids_indices.append(token_idx + num_idx + 1)
                # Advance by the real span so mixed batches stay aligned.
                token_idx += scheduled
            if update_req_indices:
                update_req_indices_tensor = async_tensor_h2d(
                    update_req_indices, torch.int32, self.device, True)
                input_ids_indices_tensor = async_tensor_h2d(
                    input_ids_indices, torch.int32, self.device, True)
                last_draft_token_ids = self.last_draft_token_ids.flatten().to(
                    torch.int)
                input_ids[input_ids_indices_tensor] = (
                    last_draft_token_ids[update_req_indices_tensor])

        # Pass 2: sampled tokens.
        if self.last_sampled_token_ids is not None:
            update_req_indices = []
            input_ids_indices = []
            token_idx = 0
            if self.speculative_config:
                sampled_tokens_num = 1
            else:
                sampled_tokens_num = self.last_sampled_token_ids.shape[1]
            for req_id in req_ids:
                row = last_row_of.get(req_id)
                if row is not None:
                    update_req_indices.append(row * sampled_tokens_num)
                    input_ids_indices.append(token_idx)
                token_idx += num_scheduled[req_id]
            if update_req_indices:
                update_req_indices_tensor = async_tensor_h2d(
                    update_req_indices, torch.int32, self.device, True)
                input_ids_indices_tensor = async_tensor_h2d(
                    input_ids_indices, torch.int32, self.device, True)
                if envs.VLLM_ZERO_OVERHEAD_ENHANCE and self.speculative_config:
                    fused_update_input_ids_impl(self.last_sampled_token_ids,
                                                input_ids,
                                                update_req_indices_tensor,
                                                input_ids_indices_tensor)
                else:
                    last_sampled_token_ids = (
                        self.last_sampled_token_ids.flatten())
                    for i in range(sampled_tokens_num):
                        input_ids[input_ids_indices_tensor + i] = (
                            last_sampled_token_ids)[update_req_indices_tensor
                                                    + i]
331
332
333
334

    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        num_accepted_tokens_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        attn_metadata: dict[str, Any],
    ) -> list[list[int]]:
        """Propose draft tokens with the configured speculative method.

        For the Eagle path the draft tokens are stored in
        ``self.last_draft_token_ids``, with an asynchronous host copy in
        ``self.last_draft_host_tokens`` followed by recording
        ``self.last_draft_event``. The value returned on that path is a
        placeholder list of ones with the same shape as the draft tokens.

        Args:
            scheduler_output: Scheduler output for the current step.
            num_accepted_tokens_tensor: Per-request column index used to pick
                the next token out of ``sampled_token_ids``.
            sampled_token_ids: Sampled token ids of the current step.
            sampling_metadata: Sampling metadata forwarded to the drafter.
            hidden_states: Hidden states of the target model.
            sample_hidden_states: Hidden states at the sampled positions.
            aux_hidden_states: Auxiliary hidden states; read only when
                ``self.use_aux_hidden_state_outputs`` is set.
            spec_decode_metadata: Speculative decoding metadata of this
                step, or ``None`` when no draft tokens were scheduled.
            attn_metadata: Layer-name to attention-metadata mapping.

        Returns:
            The speculative token ids per request (placeholders for Eagle).

        Raises:
            ValueError: If the configured speculative method is not one of
                the methods handled here.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)
        elif self.speculative_config.method == "medusa":
            assert isinstance(self.drafter, MedusaProposer)
            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                hidden_states = sample_hidden_states
            else:
                indices = []
                offset = 0
                for num_draft, tokens in zip(
                        spec_decode_metadata.num_draft_tokens,
                        sampled_token_ids):
                    indices.append(offset + len(tokens) - 1)
                    offset += num_draft + 1
                indices = torch.tensor(indices, device=self.device)
                hidden_states = sample_hidden_states[indices]

            spec_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            # Pick one token per row, at the column given by
            # num_accepted_tokens_tensor, without leaving the device.
            row_indices = torch.arange(sampled_token_ids.size(0),
                                       device=sampled_token_ids.device)
            next_token_ids = sampled_token_ids[
                row_indices, num_accepted_tokens_tensor].flatten()

            # At this moment, we assume all eagle layers belong to the same KV
            # cache group, thus using the same attention metadata.
            eagle_attn_metadata = attn_metadata[
                self.drafter.attn_layer_names[0]]

            # NOTE: deepseek_mtp uses MLA which does not have `block_table`
            if hasattr(eagle_attn_metadata, "block_table"):
                block_table = eagle_attn_metadata.block_table
            else:
                block_table = None

            spec_scheduler_max_num_tokens = self.spec_scheduler_max_num_tokens
            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[:num_scheduled_tokens]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
                target_slot_mapping = eagle_attn_metadata.slot_mapping
                cu_num_tokens = eagle_attn_metadata.query_start_loc
            else:
                # TODO(woosuk): Refactor this.
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                    eagle_attn_metadata.query_start_loc,
                    num_accepted_tokens_tensor,
                )
                # On this path the drafter is told a max of 1 instead of the
                # value recorded by _prepare_inputs.
                spec_scheduler_max_num_tokens = 1
                target_token_ids = self.input_ids[token_indices]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[token_indices]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    target_hidden_states = hidden_states[token_indices]
                target_slot_mapping = eagle_attn_metadata.slot_mapping[
                    token_indices]
            self.drafter.spec_scheduler_max_num_tokens = (
                spec_scheduler_max_num_tokens)
            draft_result = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
                block_table=block_table,
                sampling_metadata=sampling_metadata,
                decoding=spec_decode_metadata is not None,
            )

            if not envs.VLLM_REJECT_SAMPLE_OPT:
                draft_token_ids = draft_result
            else:
                # With the optimized rejection sampler the drafter also
                # returns draft probabilities, stored per request id.
                draft_token_ids, draft_probs = draft_result
                draft_req_ids = list(scheduler_output.num_scheduled_tokens)

                if self.draft_probs is None:
                    self.draft_probs = DraftProbs(
                        draft_probs, draft_req_ids)
                else:
                    self.draft_probs.update(draft_probs, draft_req_ids)

            # Return placeholder ids (all ones) with the drafts' shape; the
            # real draft tokens stay on the device in last_draft_token_ids,
            # which zero_prepare_inputs reads on a later step.
            spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
            self.last_draft_token_ids = draft_token_ids
            self.last_draft_host_tokens = draft_token_ids.to('cpu',
                                                             non_blocking=True)
            self.last_draft_event.record()
        else:
            # Fail clearly instead of hitting an UnboundLocalError on return.
            raise ValueError(
                "Unsupported speculative decoding method: "
                f"{self.speculative_config.method!r}")
        return spec_token_ids

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            return self.kv_connector_no_forward(scheduler_output)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
         spec_decode_metadata,
         num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
466
467
        
        # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
468
        if self.ep_sp or self.enable_dp_attention:
469
470
471
472
473
474
475
            num_input_tokens = round_up(num_scheduled_tokens, tp_size)
            if (self.use_cuda_graph
                    and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
                # Use piecewise CUDA graphs.
                # Add padding to the batch size.
                num_input_tokens = self.vllm_config.pad_for_cudagraph(
                    num_input_tokens)
476
        else:
477
478
479
480
481
482
            if (self.use_cuda_graph
                    and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
                # Use piecewise CUDA graphs.
                # Add padding to the batch size.
                num_input_tokens = self.vllm_config.pad_for_cudagraph(
                    num_scheduled_tokens)
483
            else:
484
485
486
487
488
489
490
491
492
                # Eager mode.
                # Pad tokens to multiple of tensor_parallel_size when
                # enabled collective fusion for SP
                tp_size = self.vllm_config.parallel_config.tensor_parallel_size
                if self.compilation_config.pass_config. \
                        enable_sequence_parallelism and tp_size > 1:
                    num_input_tokens = round_up(num_scheduled_tokens, tp_size)
                else:
                    num_input_tokens = num_scheduled_tokens
王敏's avatar
王敏 committed
493

494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
        num_input_tokens += num_pad

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.is_multimodal_model and get_pp_group().is_first_rank:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        # Some attention backends only support CUDA Graphs in pure decode.
        # If attention doesn't support CUDA Graphs for this batch, but we
        # compiled with full CUDA graphs, we have to skip them entirely.
        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
543
        if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph or skip_cuda_graphs):
544
545
546
            model_output, finished_sending, finished_recving = \
                 tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
                                             num_tokens_across_dp, input_ids, positions,
547
548
                                             inputs_embeds, scheduler_output, intermediate_tensors, 
                                             skip_cuda_graphs)
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
        else:
            # Run the model.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_input_tokens,
                    num_tokens_across_dp=num_tokens_across_dp,
                    skip_cuda_graphs=skip_cuda_graphs,
            ):
                self.maybe_setup_kv_connector(scheduler_output)

                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )

                self.maybe_wait_for_kv_save()
                finished_sending, finished_recving = (
                    self.get_finished_kv_transfers(scheduler_output))
        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping mirco-batches
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = \
            self.parallel_config.distributed_executor_backend \
            == "external_launcher" and len(get_pp_group().ranks) > 0
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(hidden_states.tensors,
                                            all_gather_group=get_tp_group())
            logits = None
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np, finished_sending,
                                  finished_recving)

            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states, None)
        if broadcast_pp_output:
            model_output_broadcast_data = {
                "logits": logits.contiguous(),
            } if logits is not None else {}
            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
            assert model_output_broadcast_data is not None
            logits = model_output_broadcast_data["logits"]

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            assert logits is not None
王敏's avatar
王敏 committed
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
            if not envs.VLLM_REJECT_SAMPLE_OPT:
                bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
                sampler_output = self.sampler(
                    logits=bonus_logits,
                    sampling_metadata=sampling_metadata,
                )
                bonus_token_ids = sampler_output.sampled_token_ids

                # Just like `bonus_logits`, `target_logits` is a new tensor with
                # separate storage from the original `logits` tensor. Therefore,
                # it is safe to update `target_logits` in place.
                target_logits = logits[spec_decode_metadata.target_logits_indices]
                output_token_ids = self.rejection_sampler(
                    spec_decode_metadata,
                    None,  # draft_probs
                    target_logits,
                    bonus_token_ids,
                    sampling_metadata,
                )
                sampler_output.sampled_token_ids = output_token_ids
            else:
王敏's avatar
王敏 committed
647
648
                sampling_metadata.all_greedy = True
                sampling_metadata.all_random = False
王敏's avatar
王敏 committed
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
                sampler_output = self.sampler(
                    logits=logits,
                    sampling_metadata=sampling_metadata,
                )
                target_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.target_logits_indices]
                target_logits = logits[spec_decode_metadata.target_logits_indices]

                bonus_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.bonus_logits_indices]

                output_token_ids = self.rejection_sampler(
                    spec_decode_metadata,
                    self.draft_probs.get_probs(spec_decode_metadata.spec_decode_ids),
                    target_logits,
                    target_token_ids,
                    bonus_token_ids,
                    sampling_metadata,
                )
                sampler_output.sampled_token_ids = output_token_ids
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698

        num_nans_in_logits = {}
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
lizhigong's avatar
lizhigong committed
699
700
701
            scheduler_output,
        )

702
703
704
        fix_draft_token_ids = None
        fix_draft_req_ids = self.last_sampled_req_ids
        is_output_valid = False
705
706
        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
707
708
        over_head_enhance = (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
                             self.speculative_config is not None)
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
        if over_head_enhance:
            # if not self.speculative_config:
            #     self.fix_req_ids = self.last_sampled_req_ids
            #     if self.last_sampler_host_tokens is not None:
            #         self.last_sampler_event.synchronize()
            #         self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
            #         for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
            #             if start_idx == -1:
            #                 self.fix_sampled_token_ids[req_idx].clear()
            #                 continue
            #             req_id = self.fix_req_ids[req_idx]
            #             if req_id in self.input_batch.req_ids:
            #                 new_req_idx = self.input_batch.req_ids.index(req_id)
            #                 self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
            #     for req_idx, req_id in enumerate(self.fix_req_ids):
            #         if req_id in self.requests:
            #             req_state = self.requests[req_id]
            #             token_idx = self.last_sampled_token_lens[req_idx]
            #             if token_idx == -1:
            #                 continue
            #             fix_len = len(self.fix_sampled_token_ids[req_idx])
            #             req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
            self.last_sampler_host_tokens = None
            self.last_sampled_token_ids = None
            self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
            self.last_sampler_event.record()
            self.last_sampled_token_ids = sampled_token_ids
            valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
737

738
739
740
741
742
        if not self.speculative_config:
            # Speculative decoding is not enabled.
            spec_token_ids = None
            fix_draft_req_ids = None
        else:
743
744
745
            if not over_head_enhance:
                sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
                self.spec_sampler_event.record()
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
            if self.last_draft_host_tokens is not None:
                self.last_draft_event.synchronize()
                fix_draft_token_ids = self.last_draft_host_tokens.tolist()

            mask = (sampled_token_ids == -1)
            mask_int = mask.int()
            first_neg_one_indices = torch.argmax(mask_int, dim=1)
            num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
            spec_token_ids = self.propose_draft_token_ids(
                scheduler_output,
                num_accepted_tokens_tensor,
                sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                attn_metadata,
            )

766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
        if not over_head_enhance:
            if self.speculative_config:
                self.spec_sampler_event.synchronize()
                max_gen_len = sampled_token_ids.shape[-1]
                if max_gen_len == 1:
                    valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
                else:
                    # Includes spec decode tokens.
                    valid_sampled_token_ids = self.rejection_sampler.parse_output(
                        sampled_token_ids_cpu,
                        self.input_batch.vocab_size,
                    )
                    self.last_sampler_host_tokens = None
                    self.last_sampled_token_ids = None
                is_output_valid = True
            else:
                # No spec decode tokens.
                self.fix_req_ids = self.last_sampled_req_ids
                if self.last_sampler_host_tokens != None:
                    self.last_sampler_event.synchronize()
                    self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
                    for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
                        if start_idx == -1:
                            continue
                        req_id = self.fix_req_ids[req_idx]
                        if req_id in self.input_batch.req_ids:
                            new_req_idx = self.input_batch.req_ids.index(req_id)
                            self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
                for req_idx, req_id in enumerate(self.fix_req_ids):
                    if req_id in self.requests:
                        req_state = self.requests[req_id]
                        token_idx = self.last_sampled_token_lens[req_idx]
                        if token_idx == -1:
799
                            self.fix_sampled_token_ids[req_idx].clear()
800
801
802
803
804
805
806
807
808
809
810
811
                            continue
                        fix_len = len(self.fix_sampled_token_ids[req_idx])
                        req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
                self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
                self.last_sampler_event.record()
                self.last_sampled_token_ids = sampled_token_ids
                valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()

        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

812
813
814
815
816
        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
817
        self.token_ids_cpu_fix_record.clear()
818
819
820
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
821
            req_id = self.input_batch.req_ids[req_idx]
822
            self.last_sampled_req_ids.append(req_id)
823
            cache_output_len = -1
824
            if not sampled_ids:
825
826
                self.last_sampled_token_lens.append(-1)
                self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
827
                continue
828

829
830
831
832
833
834
835
836
837
            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(sampled_ids)
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}")

            self.input_batch.token_ids_cpu[req_idx,
                                            start_idx:end_idx] = sampled_ids
838
            self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
839
840
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx
841
            if not over_head_enhance and req_id in self.requests:
842
843
844
                req_state = self.requests[req_id]
                cache_output_len = len(req_state.output_token_ids)
                req_state.output_token_ids.extend(sampled_ids)
845
            self.last_sampled_token_lens.append(cache_output_len)
jujl1's avatar
jujl1 committed
846

847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
        # Clear KVConnector state after all KVs are generated.
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()

        self.eplb_step()

        model_output = ZeroV1ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            num_nans_in_logits=num_nans_in_logits,
jujl1's avatar
jujl1 committed
864
865
866
867
            fix_req_ids=self.fix_req_ids,
            fix_sampled_token_ids=self.fix_sampled_token_ids,
            fix_draft_tokens_ids=fix_draft_token_ids,
            fix_draft_req_ids=fix_draft_req_ids,
868
869
870
            is_output_valid=is_output_valid
        )
        return model_output