gpu_model_runner.py 26.5 KB
Newer Older
lizhigong's avatar
lizhigong committed
1

2
from typing import Any, Optional, Union
lizhigong's avatar
lizhigong committed
3
4
import torch
import numpy as np
5
from vllm import envs
lizhigong's avatar
lizhigong committed
6
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
7
8
9
10
11
12
13
14
15
16
17
18
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.utils import async_tensor_h2d, round_up
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
lizhigong's avatar
lizhigong committed
19
20
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile
21
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
lizhigong's avatar
lizhigong committed
22
23


24
25
26
class V1ZeroModelRunner(GPUModelRunner):
    def __init__(self, vllm_config, device):
        """Build the base runner and reset the cross-step bookkeeping.

        Args:
            vllm_config: Forwarded unchanged to ``GPUModelRunner.__init__``.
            device: Forwarded unchanged to ``GPUModelRunner.__init__``.
        """
        super().__init__(vllm_config, device)

        # --- Sampled tokens of the previous step -------------------------
        # Device tensor holding the previous step's sampled token ids
        # (None until a step has stored one).
        self.last_sampled_token_ids = None
        # Request ids that received sampled tokens in the previous step.
        self.last_sampled_req_ids = []
        # For each entry of last_sampled_req_ids: length of that request's
        # output_token_ids *before* the previous step's tokens were appended.
        self.last_sampled_token_lens = []
        # Recorded right after the async device->host copy of the sampled
        # tokens is enqueued; synchronized before the host copy is read.
        self.last_sampler_event = torch.cuda.Event(enable_timing=False)
        # Host copy of the previous step's sampled tokens (async transfer).
        self.last_sampler_host_tokens = None
        # [req_idx, start_idx, end_idx] triples naming the token_ids_cpu
        # slices written in the previous step.
        self.token_ids_cpu_fix_recode = []

        # --- Draft tokens of the previous step ---------------------------
        # Device tensor with the previous step's proposed draft tokens.
        self.last_draft_token_ids = None
        # Host copy of the draft tokens (async transfer).
        self.last_draft_host_tokens = None
        # Recorded after the async device->host copy of the draft tokens.
        self.last_draft_event = torch.cuda.Event(enable_timing=False)
lizhigong's avatar
lizhigong committed
36
    
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
    def zero_prepare_inputs(self, scheduler_output, input_ids):
        """Patch ``input_ids`` in place with tokens kept on the device.

        Tokens produced by the previous step (draft tokens and sampled
        tokens) are copied from ``self.last_draft_token_ids`` and
        ``self.last_sampled_token_ids`` into the slots of ``input_ids`` that
        belong to the same requests in the current batch.

        Args:
            scheduler_output: Current scheduler output; its
                ``num_scheduled_tokens`` mapping gives the number of tokens
                each request occupies in ``input_ids``.
            input_ids: Flat token-id tensor for the current batch; modified
                in place.
        """
        req_ids = self.input_batch.req_ids
        num_scheduled = scheduler_output.num_scheduled_tokens

        # Row of every request in the previous step's outputs. Built once so
        # the loops below do not rescan the list per request; setdefault
        # keeps the first occurrence, matching list.index().
        prev_row = {}
        for row, prev_req_id in enumerate(self.last_sampled_req_ids):
            prev_row.setdefault(prev_req_id, row)

        # ---- Draft tokens: fill the slots after each request's first token.
        # NOTE(review): this assumes the rows of last_draft_token_ids follow
        # the order of last_sampled_req_ids - confirm against the producer.
        if self.last_draft_token_ids is not None:
            draft_tokens_num = self.last_draft_token_ids.shape[1]
            update_req_indices = []
            input_ids_indices = []
            token_idx = 0
            for req_id in req_ids:
                scheduled = num_scheduled[req_id]
                row = prev_row.get(req_id)
                if row is not None:
                    req_idx = row * draft_tokens_num
                    # Never write past this request's own slots.
                    for num_idx in range(min(draft_tokens_num,
                                             scheduled - 1)):
                        update_req_indices.append(req_idx + num_idx)
                        input_ids_indices.append(token_idx + num_idx + 1)
                # Advance by the tokens actually scheduled for the request,
                # consistent with the sampled-token pass below.
                token_idx += scheduled
            if update_req_indices:
                update_req_indices_tensor = async_tensor_h2d(
                    update_req_indices, torch.int32, self.device, True)
                input_ids_indices_tensor = async_tensor_h2d(
                    input_ids_indices, torch.int32, self.device, True)
                last_draft_token_ids = \
                    self.last_draft_token_ids.flatten().to(torch.int)
                input_ids[input_ids_indices_tensor] = \
                    last_draft_token_ids[update_req_indices_tensor]

        # ---- Sampled tokens: fill each request's leading slot(s).
        if self.last_sampled_token_ids is not None:
            sampled_tokens_num = self.last_sampled_token_ids.shape[1]
            update_req_indices = []
            input_ids_indices = []
            token_idx = 0
            for req_id in req_ids:
                row = prev_row.get(req_id)
                if row is not None:
                    update_req_indices.append(row * sampled_tokens_num)
                    input_ids_indices.append(token_idx)
                token_idx += num_scheduled[req_id]
            if update_req_indices:
                update_req_indices_tensor = async_tensor_h2d(
                    update_req_indices, torch.int32, self.device, True)
                input_ids_indices_tensor = async_tensor_h2d(
                    input_ids_indices, torch.int32, self.device, True)
                last_sampled_token_ids = self.last_sampled_token_ids.flatten()
                for i in range(sampled_tokens_num):
                    input_ids[input_ids_indices_tensor + i] = \
                        last_sampled_token_ids[update_req_indices_tensor + i]

    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        attn_metadata: dict[str, Any],
    ) -> list[list[int]]:
        """Ask the configured drafter for the next speculative tokens.

        Dispatches on ``self.speculative_config`` (ngram, medusa, or an
        eagle-style drafter). For the eagle path the proposed tokens stay on
        the device: they are stored in ``self.last_draft_token_ids``, an
        asynchronous host copy is started, and a same-shaped nested list of
        ones is returned in their place.

        Returns:
            The drafter's proposal for ngram/medusa, or a placeholder list
            of ones for the eagle path.
        """
        total_tokens = scheduler_output.total_num_scheduled_tokens
        spec_cfg = self.speculative_config

        if spec_cfg.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)

        elif spec_cfg.method == "medusa":
            assert isinstance(self.drafter, MedusaProposer)
            if sample_hidden_states.shape[0] != len(sampled_token_ids):
                # Draft tokens were part of the target input: pick, for each
                # request, the row of its last accepted token.
                last_rows = []
                row_start = 0
                for draft_count, accepted in zip(
                        spec_decode_metadata.num_draft_tokens,
                        sampled_token_ids):
                    last_rows.append(row_start + len(accepted) - 1)
                    row_start += draft_count + 1
                gather_idx = torch.tensor(last_rows, device=self.device)
                medusa_hidden = sample_hidden_states[gather_idx]
            else:
                # The input to the target model does not include draft tokens.
                medusa_hidden = sample_hidden_states

            spec_token_ids = self.drafter.propose(
                target_hidden_states=medusa_hidden,
                sampling_metadata=sampling_metadata,
            )

        elif spec_cfg.use_eagle():
            assert isinstance(self.drafter, EagleProposer)

            # --- Token that follows the last verified position per request.
            if self.last_sampled_token_ids is not None:
                next_token_ids = self.last_sampled_token_ids.flatten()
            else:
                pending = []
                for row, accepted in enumerate(sampled_token_ids):
                    if accepted:
                        # Common case: the last sampled token.
                        pending.append(accepted[-1])
                        continue
                    # Partial prefill (rare case): read the next token id
                    # from the request state.
                    req_id = self.input_batch.req_ids[row]
                    state = self.requests[req_id]
                    seq_len = (state.num_computed_tokens +
                               scheduler_output.num_scheduled_tokens[req_id])
                    pending.append(state.get_token_id(seq_len))
                next_token_ids = torch.tensor(pending,
                                              dtype=torch.int32,
                                              device=self.device)

            # All eagle layers are assumed to share one KV cache group, so
            # the first layer's attention metadata is used for all of them.
            layer_meta = attn_metadata[self.drafter.attn_layer_names[0]]

            # NOTE: deepseek_mtp uses MLA which does not have `block_table`
            block_table = getattr(layer_meta, "block_table", None)

            num_rejected_tokens = None
            if spec_decode_metadata is not None:
                num_rejected_tokens = [
                    (draft_count + 1 - len(sampled_token_ids[row]))
                    if draft_count > 0 else 0
                    for row, draft_count in enumerate(
                        spec_decode_metadata.num_draft_tokens)
                ]
                rejected_on_device = async_tensor_h2d(
                    num_rejected_tokens,
                    dtype=torch.int32,
                    target_device=self.device,
                    pin_memory=True)
                kept_tokens = total_tokens - sum(num_rejected_tokens)
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                    layer_meta.query_start_loc,
                    rejected_on_device,
                    kept_tokens,
                )
                selector = token_indices
                target_slot_mapping = layer_meta.slot_mapping[token_indices]
            else:
                # No draft tokens were verified: use the scheduled prefix.
                selector = slice(None, total_tokens)
                target_slot_mapping = layer_meta.slot_mapping
                cu_num_tokens = layer_meta.query_start_loc

            # input_ids can be None for multimodal models.
            target_token_ids = self.input_ids[selector]
            # TODO(woosuk): Support M-RoPE.
            target_positions = self.positions[selector]
            if self.use_aux_hidden_state_outputs:
                target_hidden_states = torch.cat(
                    [aux[selector] for aux in aux_hidden_states], dim=-1)
            else:
                target_hidden_states = hidden_states[selector]

            draft_token_ids = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
                block_table=block_table,
                sampling_metadata=sampling_metadata,
                num_rejected_tokens=num_rejected_tokens
            )

            # Hand back placeholders; the real tokens stay on the device and
            # are copied to the host asynchronously.
            spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
            self.last_draft_token_ids = draft_token_ids
            self.last_draft_host_tokens = draft_token_ids.to(
                'cpu', non_blocking=True)
            self.last_draft_event.record()
        return spec_token_ids

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        """Run one forward + sampling step without waiting on sampled tokens.

        Without speculative decoding, the sampled tokens are copied to the
        host asynchronously and a same-shaped list of ones is returned as a
        placeholder. The real tokens of the *previous* step are read back
        here (after synchronizing on ``last_sampler_event``), written into
        ``token_ids_cpu`` and the request states, and reported through the
        ``fix_*`` fields of the output. With speculative decoding, the
        sampled tokens are materialised in this step and
        ``is_output_valid`` is set to True.

        Args:
            scheduler_output: What the scheduler decided to run this step.
            intermediate_tensors: Tensors received from the previous
                pipeline stage; replaced by None on the first PP rank.

        Returns:
            ``EMPTY_MODEL_RUNNER_OUTPUT`` (or the result of
            ``kv_connector_no_forward``) when nothing is scheduled, the
            hidden states on a non-last pipeline rank that does not
            broadcast, the result of ``_pool`` when pooling params are
            present, and a ``ZeroV1ModelRunnerOutput`` otherwise.
        """
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            return self.kv_connector_no_forward(scheduler_output)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
         spec_decode_metadata,
         num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            # Pad tokens to multiple of tensor_parallel_size when
            # enabled collective fusion for SP
            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
            if self.compilation_config.pass_config. \
                enable_sequence_parallelism and tp_size > 1:
                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
            else:
                num_input_tokens = num_scheduled_tokens

        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
        num_input_tokens += num_pad

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.is_multimodal_model and get_pp_group().is_first_rank:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        # Some attention backends only support CUDA Graphs in pure decode.
        # If attention doesn't support CUDA Graphs for this batch, but we
        # compiled with full CUDA graphs, we have to skip them entirely.
        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs

        # Overwrite placeholder token ids with the previous step's
        # device-resident tokens before the forward pass reads them.
        self.zero_prepare_inputs(scheduler_output, input_ids)

        if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
            model_output, finished_sending, finished_recving = \
                 tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
                                             num_tokens_across_dp, input_ids, positions,
                                             inputs_embeds, scheduler_output, intermediate_tensors)
        else:
            # Run the model.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_input_tokens,
                    num_tokens_across_dp=num_tokens_across_dp,
                    skip_cuda_graphs=skip_cuda_graphs,
            ):
                self.maybe_setup_kv_connector(scheduler_output)

                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )

                self.maybe_wait_for_kv_save()
                finished_sending, finished_recving = (
                    self.get_finished_kv_transfers(scheduler_output))
        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping micro-batches
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = \
            self.parallel_config.distributed_executor_backend \
            == "external_launcher" and len(get_pp_group().ranks) > 0
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(hidden_states.tensors,
                                            all_gather_group=get_tp_group())
            logits = None
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np, finished_sending,
                                  finished_recving)

            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states, None)
        if broadcast_pp_output:
            model_output_broadcast_data = {
                "logits": logits.contiguous(),
            } if logits is not None else {}
            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
            assert model_output_broadcast_data is not None
            logits = model_output_broadcast_data["logits"]

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            assert logits is not None
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        num_nans_in_logits = {}
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output,
        )

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]

        fix_req_ids = None
        fix_sampled_token_ids = None
        fix_draft_token_ids = None
        # Captured before last_sampled_req_ids is rebuilt further down.
        fix_draft_req_ids = self.last_sampled_req_ids
        is_output_valid = False
        if self.speculative_config:
            if max_gen_len == 1:
                valid_sampled_token_ids = sampled_token_ids.tolist()
            else:
                # Includes spec decode tokens.
                valid_sampled_token_ids = self.rejection_sampler.parse_output(
                    sampled_token_ids,
                    self.input_batch.vocab_size,
                )
                self.last_sampler_host_tokens = None
                self.last_sampled_token_ids = None
            is_output_valid = True
        else:
            # No spec decode tokens.
            fix_req_ids = self.last_sampled_req_ids
            # Identity check: `!= None` on a tensor is a value comparison,
            # not a presence test.
            if self.last_sampler_host_tokens is not None:
                # Wait for the previous step's async device->host copy.
                self.last_sampler_event.synchronize()
                fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
                # Replace the placeholders written last step with real tokens.
                for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_recode:
                    self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
            for req_idx, req_id in enumerate(fix_req_ids):
                if req_id in self.requests:
                    req_state = self.requests[req_id]
                    token_idx = self.last_sampled_token_lens[req_idx]
                    req_state.output_token_ids[token_idx] = fix_sampled_token_ids[req_idx][0]
            # Start the async copy of this step's tokens and hand back
            # placeholders (ones) of the same shape.
            self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
            self.last_sampler_event.record()
            self.last_sampled_token_ids = sampled_token_ids
            valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()

        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        self.token_ids_cpu_fix_recode.clear()
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
            if not sampled_ids:
                continue

            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(sampled_ids)
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}")

            self.input_batch.token_ids_cpu[req_idx,
                                            start_idx:end_idx] = sampled_ids
            # Remember the slice so the next step can overwrite it with the
            # real token values.
            self.token_ids_cpu_fix_recode.append([req_idx, start_idx, end_idx])
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx
            req_id = self.input_batch.req_ids[req_idx]
            if req_id in self.requests:
                req_state = self.requests[req_id]
                self.last_sampled_req_ids.append(req_id)
                self.last_sampled_token_lens.append(len(req_state.output_token_ids))
                req_state.output_token_ids.extend(sampled_ids)

        if not self.speculative_config:
            # Speculative decoding is not enabled.
            spec_token_ids = None
            fix_draft_req_ids = None
        else:
            if self.last_draft_host_tokens is not None:
                # Wait for the previous step's async draft-token copy.
                self.last_draft_event.synchronize()
                fix_draft_token_ids = self.last_draft_host_tokens.tolist()
            spec_token_ids = self.propose_draft_token_ids(
                scheduler_output,
                valid_sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                attn_metadata,
            )

        # Clear KVConnector state after all KVs are generated.
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()

        self.eplb_step()

        model_output = ZeroV1ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            num_nans_in_logits=num_nans_in_logits,
            fix_req_ids = fix_req_ids,
            fix_sampled_token_ids = fix_sampled_token_ids,
            fix_draft_tokens_ids = fix_draft_token_ids, 
            fix_draft_req_ids = fix_draft_req_ids,
            is_output_valid=is_output_valid
        )
        return model_output