# xpu_model_runner.py (24.9 KB)
1
2
3
import dataclasses
import time
import weakref
4
from collections import defaultdict
5
from dataclasses import dataclass
6
7
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, TypeVar)
8
9
10
11
12

import torch
import torch.nn as nn

from vllm.attention import get_attn_backend
13
from vllm.config import VllmConfig
14
from vllm.distributed import get_pp_group
15
from vllm.forward_context import set_forward_context
16
from vllm.inputs import INPUT_REGISTRY, InputRegistry
17
from vllm.logger import init_logger
18
from vllm.model_executor import SamplingMetadataCache
19
from vllm.model_executor.layers.sampler import SamplerOutput
20
from vllm.model_executor.model_loader import get_model
21
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
22
                             MultiModalKwargs, MultiModalPlaceholderMap,
23
                             MultiModalRegistry)
24
from vllm.sampling_params import SamplingParams
25
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
26
from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad
27
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
28
from vllm.worker.model_runner_base import (
29
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
30
31
32
33
34
35
36
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
37
38
39
40
41

# Slot id written into the slot mapping for tokens that must not be stored
# in the KV cache: every prompt token while block tables are not yet
# allocated (memory profiling), and prompt tokens that fall before the
# sliding window.
_PAD_SLOT_ID = -1

# Bound to ModelInputForXPU so that ``from_broadcasted_tensor_dict`` is typed
# as returning the concrete subclass it is invoked on.
TModelInputForXPU = TypeVar("TModelInputForXPU", bound="ModelInputForXPU")

logger = init_logger(__name__)

44

45
46
47
48
49
50
51
52
@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
    """
    Model input consumed by the XPUModelRunner.

    Instances are produced by ModelInputForXPUBuilder. Only the entries that
    ``as_broadcastable_tensor_dict`` writes survive a round trip through
    ``from_broadcasted_tensor_dict``; ``seq_lens``, ``query_lens`` and
    ``async_callback`` fall back to their defaults on the receiving side.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Serialize the fields needed to rebuild this input elsewhere.

        Returns:
            A dict holding the token/position tensors, the multi-modal
            kwargs, the virtual engine index, and the entries added by
            ``_add_attn_metadata_broadcastable_dict``. It is suitable as
            input to ``from_broadcasted_tensor_dict``.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            # The model forward pass consumes the multi-modal kwargs and the
            # forward context consumes virtual_engine. A rebuilt input that
            # lacks them would run a different forward pass than the
            # original.
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForXPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForXPU:
        """Rebuild a model input from ``as_broadcastable_tensor_dict`` output.

        Args:
            tensor_dict: Dict produced by ``as_broadcastable_tensor_dict``.
            attn_backend: When given, ``tensor_dict`` is first passed through
                ``_init_attn_metadata_from_tensor_dict`` so the attention
                entries can be restored for this backend.

        Returns:
            An instance of ``cls`` built from the (possibly transformed) dict.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclass(frozen=True)
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
    """
    ModelInputForXPU plus the sampling metadata that XPUModelRunner uses to
    compute logits and sample the next tokens.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Serialize the fields needed to rebuild this input elsewhere.

        Returns:
            A dict holding the token/position tensors, the multi-modal
            kwargs, the virtual engine index, and the entries added by
            ``_add_attn_metadata_broadcastable_dict`` and
            ``_add_sampling_metadata_broadcastable_dict``.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            # Both are consumed by execute_model (model kwargs / forward
            # context), so they must survive the round trip as well.
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForXPUWithSamplingMetadata":
        """Rebuild a model input from ``as_broadcastable_tensor_dict`` output.

        Args:
            tensor_dict: Dict produced by ``as_broadcastable_tensor_dict``.
                It always passes through
                ``_init_sampling_metadata_from_tensor_dict`` first.
            attn_backend: When given, the dict is then also passed through
                ``_init_attn_metadata_from_tensor_dict``.

        Returns:
            An instance of ``cls`` built from the transformed dict.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


110
111
112
113
114
115
116
117
118
119
120
121
122
class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
    """Collects the sequence groups of one step and builds the model input.

    Prompt (prefill) and decode batches are prepared by separate paths. The
    kind of batch is decided by the first sequence group that was added.
    """

    def __init__(self,
                 runner: "XPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        """Copy the runner settings this builder reads for every batch.

        Args:
            runner: The owning model runner.
            finished_requests_ids: Accepted for interface compatibility;
                not used here.
        """
        super().__init__()
        self.runner = runner
        self.model_input_cls = runner._model_input_cls
        self.attn_backend = runner.attn_backend
        self.sliding_window = runner.sliding_window
        self.block_size = runner.block_size
        self.device = runner.device

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Start a new batch by forgetting previously added groups.

        ``finished_requests_ids`` is accepted for interface compatibility and
        ignored.
        """
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = list()

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Queue one sequence group for the batch being built."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> ModelInputForXPU:
        """Turn the queued groups into a model input.

        For decode batches ``seq_lens``/``query_lens`` and
        ``multi_modal_kwargs`` are left as ``None``. For prompt batches both
        ``seq_lens`` and ``query_lens`` receive the prompt lengths.
        """
        groups = self.seq_group_metadata_list
        if groups[0].is_prompt:
            (tokens, positions, attn_metadata, lens,
             mm_kwargs) = self._prepare_prompt(groups)
        else:
            tokens, positions, attn_metadata = self._prepare_decode(groups)
            lens = None
            mm_kwargs = None

        return self.model_input_cls(
            input_tokens=tokens,
            input_positions=positions,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=mm_kwargs,
            seq_lens=lens,
            query_lens=lens,
        )

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        """Build flattened prefill inputs for a batch of prompt groups.

        Every group must be a prompt holding exactly one sequence. The
        prompt tokens of all groups are concatenated into one flat batch.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, seq_lens,
            multi_modal_kwargs)``, where ``seq_lens`` holds each prompt's
            token count.
        """
        assert len(seq_group_metadata_list) > 0
        token_ids: List[int] = []
        positions: List[int] = []
        slots: List[int] = []
        prompt_lens: List[int] = []
        mm_kwargs_per_group: List[MultiModalKwargs] = []
        placeholder_maps_by_modality: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        for group in seq_group_metadata_list:
            assert group.is_prompt
            seq_ids = list(group.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = group.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            prompt_len = len(prompt_tokens)

            prompt_lens.append(prompt_len)
            token_ids.extend(prompt_tokens)

            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            position_range = range(computed_len, prompt_len)
            positions.extend(position_range)

            if group.multi_modal_data:
                # NOTE: mm_data only includes the subset of multi-modal items
                # that intersect with the current prefill positions.
                mm_data, group_placeholder_maps = MultiModalPlaceholderMap \
                    .from_seq_group(group, position_range)
                uses_processor = self.runner.mm_registry.has_processor(
                    self.runner.model_config)
                # A registered processor already yields model-ready kwargs;
                # otherwise the legacy input mapper converts the raw data.
                mm_kwargs_per_group.append(
                    mm_data if uses_processor else
                    self.runner.multi_modal_input_mapper(
                        mm_data,
                        group.mm_processor_kwargs,
                    ))
                for modality, pmap in group_placeholder_maps.items():
                    placeholder_maps_by_modality[modality].extend(pmap)

            block_tables = group.block_tables
            if block_tables is None:
                # During memory profiling the block tables are not
                # initialized yet, so every token gets the padding slot.
                slots.extend([_PAD_SLOT_ID] * prompt_len)
            else:
                slots.extend(
                    self._prompt_slot_mapping(block_tables[seq_id],
                                              computed_len, prompt_len))

        num_prompt_tokens = len(token_ids)
        device = self.device
        input_tokens = torch.tensor(token_ids, dtype=torch.long, device=device)
        input_positions = torch.tensor(positions,
                                       dtype=torch.long,
                                       device=device)
        slot_mapping = torch.tensor(slots, dtype=torch.long, device=device)

        placeholder_index_maps = {
            modality: pmap.index_map()
            for modality, pmap in placeholder_maps_by_modality.items()
        }

        # Prefix sums of the prompt lengths with a leading 0: the start
        # offset of each prompt in the flattened batch, followed by the
        # total token count.
        seqlen_q = torch.cumsum(torch.tensor([0, *prompt_lens]),
                                dim=0).to(device=device)

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            seq_lens=prompt_lens,
            seqlen_q=seqlen_q,
            max_seqlen=max(prompt_lens),
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(prompt_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=device, dtype=torch.int),
        )

        multi_modal_kwargs = MultiModalKwargs.batch(mm_kwargs_per_group)
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                multi_modal_kwargs)

    def _prompt_slot_mapping(self, block_table: List[int], computed_len: int,
                             prompt_len: int) -> List[int]:
        """Return KV-cache slots for prompt positions [computed_len, prompt_len).

        Positions before ``max(0, prompt_len - sliding_window)`` lie outside
        the sliding window and map to ``_PAD_SLOT_ID``. For example, with
        prompt length 10, sliding window 8, block size 4 and block table
        [0, 1, 0], the result is [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
        """
        window_start = 0
        if self.sliding_window is not None:
            window_start = max(0, prompt_len - self.sliding_window)

        mapping: List[int] = []
        for pos in range(computed_len, prompt_len):
            if pos < window_start:
                mapping.append(_PAD_SLOT_ID)
                continue
            block_number = block_table[pos // self.block_size]
            mapping.append(block_number * self.block_size +
                           pos % self.block_size)
        return mapping

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        """Build inputs for one decode step (one new token per sequence).

        With a sliding window, each attended length is capped at the window
        size. Each block table is cut to its last
        ``sliding_window // block_size`` entries, or kept whole when that
        quotient is 0.

        Returns:
            ``(input_tokens, input_positions, attn_metadata)``.
        """
        assert len(seq_group_metadata_list) > 0
        block_size = self.block_size
        window = self.sliding_window

        last_token_ids: List[int] = []
        positions: List[int] = []
        slots: List[int] = []
        context_lens: List[int] = []
        tables: List[List[int]] = []

        for group in seq_group_metadata_list:
            assert not group.is_prompt
            assert group.token_chunk_size == 1

            for seq_id, seq_data in group.seq_data.items():
                last_token_ids.append(seq_data.get_last_token_id())

                total_len = seq_data.get_len()
                # The token being decoded occupies the last position.
                position = total_len - 1
                positions.append(position)
                context_lens.append(
                    total_len if window is None else min(total_len, window))

                table = group.block_tables[seq_id]
                slots.append(table[position // block_size] * block_size +
                             position % block_size)
                if window is not None:
                    table = table[-(window // block_size):]
                tables.append(table)

        device = self.device
        input_tokens = torch.tensor(last_token_ids,
                                    dtype=torch.long,
                                    device=device)
        input_positions = torch.tensor(positions,
                                       dtype=torch.long,
                                       device=device)
        slot_mapping = torch.tensor(slots, dtype=torch.long, device=device)
        seq_lens_tensor = torch.tensor(context_lens,
                                       dtype=torch.int,
                                       device=device)
        block_tables = make_tensor_with_pad(
            tables,
            pad=0,
            dtype=torch.int,
            device=device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            seq_lens=context_lens,
            seqlen_q=torch.tensor([]),
            max_seqlen=0,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max(context_lens),
            num_prefill_tokens=0,
            num_decode_tokens=len(last_token_ids),
            num_prefills=0,
            block_tables=block_tables,
        )
        return input_tokens, input_positions, attn_metadata


class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
    """Runs a vLLM model on an Intel XPU device.

    Model inputs are built with ``ModelInputForXPUBuilder``. Logits are
    computed on the last pipeline stage, and next-token sampling happens only
    on the driver worker.
    """
    _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
        ModelInputForXPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """
        Args:
            vllm_config: Engine configuration. The sub-configs read below
                (``model_config``, ``cache_config``, ``device_config``,
                ``parallel_config``, ...) are set by
                ``ModelRunnerBase.__init__``.
            kv_cache_dtype: KV cache dtype passed to the attention backend
                selector.
            is_driver_worker: Only the driver worker performs sampling.
            return_hidden_states: Stored on the instance; not read by this
                class.
            input_registry: Registry used to build dummy profiling inputs.
            mm_registry: Multi-modal registry (input mapper, limits,
                profiling token counts).
        """
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config
        self.is_driver_worker = is_driver_worker
        # NOTE(review): stored but never read by execute_model here.
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry \
            .create_input_mapper(model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization: assigned by load_model().
        self.model: nn.Module

        # The sampling metadata cache is only used without pipeline
        # parallelism; otherwise this attribute is None.
        self.sampling_metadata_cache: Optional[SamplingMetadataCache] = (
            SamplingMetadataCache()
            if self.parallel_config.pipeline_parallel_size == 1 else None)

        # The builder keeps only a weak proxy to the runner, so it does not
        # hold a strong reference back to its owner.
        self.builder = self._builder_cls(weakref.proxy(self))

    def load_model(self) -> None:
        """Load the model weights and log how much device memory they use."""
        with DeviceMemoryProfiler() as m:
            self.model = get_model(vllm_config=self.vllm_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

    def get_model(self) -> nn.Module:
        """Return the loaded model (``load_model`` must have run)."""
        return self.model

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        return self.model_config.get_vocab_size()

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on dummy prompts to profile memory usage.

        The dummy batch spreads ``max_num_batched_tokens`` tokens over at
        most ``max_num_seqs`` prompt sequences (fewer when multi-modal
        inputs are present). Each sequence has ``block_tables=None``, which
        the builder maps to padding slots. Empty tensors stand in for the
        KV caches.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for the
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        # Total number of dummy tokens across all sequences.
        batch_size = 0
        for group_id in range(max_num_seqs):
            # Distribute the token budget as evenly as possible; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=None,
                multi_modal_data=dummy_data.multi_modal_data,
                multi_modal_placeholders=dummy_data.multi_modal_placeholders)
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # use an empty tensor instead of `None`` to force Dynamo to pass
        # it by reference, rather by specializing on the value ``None``.
        # the `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
        ] * num_layers
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.xpu.synchronize()

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForXPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast dict using this runner's
        attention backend."""
        return (
            ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            ))

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        The runner's single builder instance is reset and reused on every
        call.
        """
        builder = self.builder
        builder.prepare(finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)

        return builder.build()  # type: ignore

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is built on every rank, although only the last
        # pipeline stage uses it (execute_model returns early elsewhere).
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            model_input.seq_lens,
            model_input.query_lens,
            self.device,
            pin_memory=False,
            generators=generators,
            cache=self.sampling_metadata_cache)

        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForXPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run the model forward pass and, where applicable, sample.

        Returns:
            The hidden/intermediate states on pipeline ranks other than the
            last one, ``[]`` on non-driver workers of the last rank, and
            ``[SamplerOutput]`` on the driver worker.

        Raises:
            ValueError: If ``num_steps > 1``; multi-step is unsupported.
        """
        if num_steps > 1:
            raise ValueError(
                "XPUModelRunner does not support multi-step execution.")

        model_executable = self.model
        collect_forward_time = (
            self.observability_config is not None
            and self.observability_config.collect_model_forward_time)
        # NOTE(review): time.time() is read without synchronizing the XPU, so
        # the measured interval may reflect host-side dispatch rather than
        # device execution — confirm before relying on it.
        if collect_forward_time:
            model_forward_start_time = time.time()

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
                                             or {},
                                             device=self.device))
        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        if collect_forward_time:
            model_forward_end_time = time.time()

        # Compute the logits.
        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if collect_forward_time and output is not None:
            model_forward_time = (model_forward_end_time -
                                  model_forward_start_time)
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
            output.model_forward_time = model_forward_time

        return [output]