cpu_model_runner.py 27.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import dataclasses
import weakref
6
from collections import defaultdict
7
from dataclasses import dataclass
8
9
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
                    TypeVar, Union)
10
11

import torch
12
from torch import nn
13
14

from vllm.attention import AttentionMetadata, get_attn_backend
15
from vllm.config import VllmConfig
youkaichao's avatar
youkaichao committed
16
from vllm.forward_context import set_forward_context
17
from vllm.logger import init_logger
18
19
20
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
21
from vllm.model_executor import SamplingMetadata
22
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
from vllm.model_executor.model_loader import get_model
25
from vllm.model_executor.models import supports_lora, supports_multimodal
26
27
from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs,
                             MultiModalPlaceholderMap)
28
29
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
30
from vllm.worker.model_runner_base import (
31
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
32
33
34
35
36
37
38
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

# Bound TypeVar so the shared runner/base helpers can be typed against
# whichever ModelInputForCPU subclass a concrete runner produces.
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")

# Placeholder written into each prompt slot-mapping entry before the real
# KV-cache slot for that position is computed and stored over it.
_PAD_SLOT_ID = -1


@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Base class containing the metadata needed for the base model forward
    pass on CPU.

    Instances are frozen dataclasses; derive modified copies with
    `dataclasses.replace`.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    token_type_ids: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Collect this input's fields into a dict for broadcasting.

        Values are not all tensors (`lora_requests` is a set and
        `lora_mapping` a `LoRAMapping`), hence the `Any` value type.
        Attention metadata entries are added through
        `_add_attn_metadata_broadcastable_dict`. `virtual_engine`,
        `seq_lens` and `query_lens` are not included.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForCPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> TModelInputForCPU:
        """Rebuild an instance from a dict produced by
        `as_broadcastable_tensor_dict`.

        When `attn_backend` is given, the dict is first passed through
        `_init_attn_metadata_from_tensor_dict` so the attention metadata
        entries are reconstructed before the dataclass is built.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    Model input used by `CPUModelRunner`: the base CPU model input plus
    sampling metadata and a flag telling whether the batch is a prompt
    (prefill) batch.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Collect this input's fields into a dict for broadcasting.

        Includes the LoRA fields like the parent class does: a worker that
        rebuilds its input from this dict would otherwise get `None` for
        both, and `CPUModelRunner.execute_model` asserts they are set
        whenever LoRA is enabled. Attention and sampling metadata entries
        are added through their respective `_add_*_broadcastable_dict`
        helpers.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        """Rebuild an instance from a dict produced by
        `as_broadcastable_tensor_dict`.

        Sampling metadata entries are always reconstructed; attention
        metadata entries only when `attn_backend` is given.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
    """Builds a `ModelInputForCPU` batch from scheduled sequence groups.

    Flow: `prepare()` resets per-batch state, `add_seq_group()` or
    `set_seq_group_list()` supplies the sequence groups, and `build()`
    turns them into token/position tensors plus attention, LoRA and
    multi-modal metadata.
    """

    class ModelInputData:
        """Mutable per-batch accumulator filled while walking sequences."""

        def __init__(self, use_mrope: bool):
            self.use_mrope = use_mrope
            self.input_tokens: List[int] = []
            self.input_positions: List[int] = []
            self.token_type_ids: Optional[List[int]] = []
            # One entry per sequence, in the order sequences are visited
            # by `_build_input_data`.
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []
            self.prefill_block_tables: List[List[int]] = []
            self.decode_block_tables: List[List[int]] = []
            self.max_decode_seq_len: int = 0
            self.num_prefills: int = 0
            self.num_prefill_tokens: int = 0
            self.num_decode_tokens: int = 0
            self.slot_mapping: List[int] = []
            self.multi_modal_inputs_list: List[MultiModalKwargs] = []
            self.multi_modal_placeholder_maps: Dict[
                str, MultiModalPlaceholderMap] = defaultdict(
                    MultiModalPlaceholderMap)
            # Three parallel position lists used with MRotaryEmbedding.
            self.input_mrope_positions: List[List[int]] = [[]
                                                           for _ in range(3)]

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        super().__init__()
        self.runner = runner
        self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                or runner.cache_config.enable_prefix_caching)
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device
        self.enable_lora = self.runner.lora_config is not None
        # Spec decode runners (e.g. Medusa) have no attention backend and
        # therefore no metadata builder; prepare()/build() check for None
        # instead of failing with AttributeError. Created last because the
        # builder receives `self` and may read the attributes set above.
        self.att_metadata_builder: Any = None
        if self.attn_backend is not None:
            self.att_metadata_builder = (
                self.attn_backend.get_builder_cls()(self))

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Reset per-batch state before sequence groups are added."""
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.input_data = ModelInputForCPUBuilder.ModelInputData(
            self.runner.model_config.uses_mrope)
        # Prepared after `input_data` is replaced; presumably the metadata
        # builder picks up the fresh `input_data` here — keep this order.
        if self.att_metadata_builder is not None:
            self.att_metadata_builder.prepare()

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Append one sequence group to the current batch."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def set_seq_group_list(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Replace the current batch's sequence groups wholesale."""
        self.seq_group_metadata_list = seq_group_metadata_list

    def build(self) -> ModelInputForCPU:
        """Convert the accumulated sequence groups into a model input."""
        self._build_input_data()

        input_data = self.input_data
        input_tokens = torch.tensor(input_data.input_tokens,
                                    dtype=torch.long,
                                    device="cpu")
        # NOTE(review): as soon as any sequence produced M-RoPE positions,
        # the 1-D `input_positions` of the other sequences are ignored;
        # confirm batches mixing both kinds cannot occur.
        positions = (input_data.input_mrope_positions
                     if any(input_data.input_mrope_positions) else
                     input_data.input_positions)
        input_positions = torch.tensor(positions,
                                       dtype=torch.long,
                                       device="cpu")
        token_type_ids = None
        if input_data.token_type_ids:
            token_type_ids = torch.tensor(input_data.token_type_ids,
                                          dtype=torch.long,
                                          device="cpu")

        # For multi-modal models
        multi_modal_kwargs = None
        if input_data.multi_modal_inputs_list:
            multi_modal_kwargs = MultiModalKwargs.batch(
                input_data.multi_modal_inputs_list)

        attn_metadata = None
        if self.att_metadata_builder is not None:
            attn_metadata = self.att_metadata_builder.build(
                input_data.seq_lens, input_data.query_lens, -1, -1)

        is_prompt = (self.seq_group_metadata_list[0].is_prompt
                     if self.seq_group_metadata_list else None)
        # LoRA data.
        lora_requests: Set[LoRARequest] = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = set(seq.lora_request
                                for seq in self.seq_group_metadata_list
                                if seq.lora_request is not None)

            lora_mapping = self._prepare_lora_input(
                self.seq_group_metadata_list, is_prompt)

        return self.model_input_cls(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    token_type_ids=token_type_ids,
                                    seq_lens=input_data.seq_lens,
                                    query_lens=input_data.query_lens,
                                    attn_metadata=attn_metadata,
                                    multi_modal_kwargs=multi_modal_kwargs,
                                    lora_mapping=lora_mapping,
                                    lora_requests=lora_requests)

    def _build_input_data(self):
        """Walk every sequence of every group and fill `self.input_data`."""
        for seq_group_metadata in self.seq_group_metadata_list:
            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                if seq_group_metadata.is_prompt:
                    self._compute_prompt_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)
                    if seq_group_metadata.multi_modal_data:
                        self._compute_multi_modal_input(
                            seq_group_metadata, seq_data)
                else:
                    self._compute_decode_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)

    def _compute_decode_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute decode input tokens, positions, block table and slot mapping.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()

        # Decode feeds a single token: the last one, at position seq_len - 1.
        tokens = seq_data.get_last_token_id()
        token_positions = seq_len - 1
        block_number = block_table[token_positions // block_size]
        block_offset = token_positions % block_size
        slot = block_number * block_size + block_offset

        # For MRotaryEmbedding. Done before the sliding-window trim below so
        # the positions are derived from the real sequence length rather
        # than the window-relative one.
        if seq_data.mrope_position_delta is not None:
            next_pos = MRotaryEmbedding.get_next_input_positions(
                seq_data.mrope_position_delta,
                context_len,
                seq_len,
            )
            for idx in range(3):
                data.input_mrope_positions[idx].extend(  # type: ignore
                    next_pos[idx])
        else:
            data.input_positions.append(token_positions)  # type: ignore

        # For paged_attention kernel: keep only the blocks that overlap the
        # sliding window and make seq_len relative to the first kept block.
        if self.runner.sliding_window:
            start_idx = max(0, seq_len - self.runner.sliding_window)
            start_block = start_idx // block_size
            start_idx = start_block * block_size
            seq_len = seq_len - start_idx
            block_table = block_table[start_block:]

        # Update fields
        data.input_tokens.append(tokens)
        data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
        data.num_decode_tokens += 1
        data.slot_mapping.append(slot)
        data.decode_block_tables.append(block_table)
        data.query_lens.append(1)
        data.seq_lens.append(seq_len)

    def _compute_prompt_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute prompt input tokens, positions, block table and slot mapping.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()
        # Only `token_chunk_size` new tokens are scheduled in this step.
        seq_len = min(seq_len,
                      context_len + seq_group_metadata.token_chunk_size)

        # For prefix caching. A missing (None) block list means no hit.
        computed_block_nums = seq_group_metadata.computed_block_nums or []
        prefix_cache_block_num = len(computed_block_nums)
        if prefix_cache_block_num > 0:
            prefix_cache_len = prefix_cache_block_num * block_size
            if prefix_cache_len <= context_len:
                # We already passed the cache hit region,
                # so do normal computation.
                pass
            elif context_len < prefix_cache_len < seq_len:
                # Partial hit. Compute the missing part.
                context_len = prefix_cache_len
            elif seq_len <= prefix_cache_len:
                # Full hit. Only compute the last token to avoid
                # erroneous behavior. FIXME: Ideally we should directly
                # mark all tokens as computed in the scheduler and do not
                # schedule this sequence, so this case should not happen.
                context_len = seq_len - 1

        tokens = seq_data.get_token_ids()[context_len:seq_len]
        token_positions = range(context_len, seq_len)
        token_types = seq_group_metadata.token_type_ids

        # For encoder-only models, the block_table is None,
        # and there is no need to initialize the slot_mapping.
        if block_table is not None:
            slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
            for i, pos in enumerate(token_positions):
                block_number = block_table[pos // block_size]
                block_offset = pos % block_size
                slot = block_number * block_size + block_offset
                slot_mapping[i] = slot
            data.slot_mapping.extend(slot_mapping)

        # The MROPE positions are prepared in _compute_multi_modal_input
        data.input_positions.extend(token_positions)

        if data.token_type_ids is not None:
            # Slice to the tokens computed in this step so token types stay
            # aligned with `tokens` when a prefix is skipped (prefix cache
            # hit or a later chunk). Identical to the full list otherwise.
            data.token_type_ids.extend(
                token_types[context_len:seq_len] if token_types else [])

        # Update fields
        data.input_tokens.extend(tokens)
        data.num_prefills += 1
        data.num_prefill_tokens += len(tokens)
        data.query_lens.append(len(tokens))
        data.prefill_block_tables.append(block_table)
        data.seq_lens.append(seq_len)

    def _compute_multi_modal_input(self,
                                   seq_group_metadata: SequenceGroupMetadata,
                                   seq_data: SequenceData):
        """Collect multi-modal kwargs/placeholders for the current prefill.

        Must run right after `_compute_prompt_input_tokens` for the same
        sequence, since it reads that sequence's length from
        `self.input_data.seq_lens[-1]`.
        """
        computed_len = seq_data.get_num_computed_tokens()
        seq_len = self.input_data.seq_lens[-1]

        # NOTE: mm_kwargs only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata, range(computed_len, seq_len))

        if not mm_kwargs:
            return

        # special processing for mrope position deltas.
        if self.runner.model_config.uses_mrope:
            assert not self.chunked_prefill, \
                "MROPE on CPU does not support chunked-prefill."

            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
                                                  None)
            assert (
                image_grid_thw is not None or video_grid_thw is not None
                or audio_feature_lengths is not None), (
                    "mrope embedding type requires multi-modal input mapper "
                    "returns 'image_grid_thw' or 'video_grid_thw' or "
                    "'audio_feature_lengths'.")

            second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
            use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
            hf_config = self.runner.model_config.hf_config
            token_ids = seq_data.get_token_ids()

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    hf_config=hf_config,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    context_len=computed_len,
                    audio_feature_lengths=audio_feature_lengths,
                    use_audio_in_video=use_audio_in_video,
                )
            # Stored on the sequence so later decode steps can continue the
            # M-RoPE positions (see _compute_decode_input_tokens).
            seq_data.mrope_position_delta = mrope_position_delta

            for i in range(3):
                self.input_data.input_mrope_positions[  # type: ignore
                    i].extend(mrope_positions[i])

        self.input_data.multi_modal_inputs_list.append(mm_kwargs)
        for modality, placeholder_map in placeholder_maps.items():
            self.input_data.multi_modal_placeholder_maps[modality].extend(
                placeholder_map)

    def _prepare_lora_input(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            is_prefill: bool) -> LoRAMapping:
        """Build the per-token and per-sample-row LoRA id mapping.

        Query lengths come from `self.input_data.query_lens`, which
        `_build_input_data` fills with exactly one entry per sequence in the
        same group/sequence order iterated here. Using them (rather than
        each group's `token_chunk_size`) keeps `index_mapping` aligned with
        `input_tokens` when prefix caching shortens a prompt and when a
        decode group contains several sequences.
        """
        index_mapping: List[int] = []
        prompt_mapping: List[int] = []
        query_lens = iter(self.input_data.query_lens)
        for seq_group in seq_group_metadata_list:
            lora_id = seq_group.lora_int_id
            sampling_params = seq_group.sampling_params
            # With prompt logprobs every query token yields a sample row;
            # otherwise each sequence contributes a single row.
            all_rows = bool(sampling_params
                            and sampling_params.prompt_logprobs is not None)
            for _ in seq_group.seq_data:
                query_len = next(query_lens)
                index_mapping += [lora_id] * query_len
                prompt_mapping += [lora_id] * (query_len if all_rows else 1)

        return LoRAMapping(index_mapping=tuple(index_mapping),
                           prompt_mapping=tuple(prompt_mapping),
                           is_prefill=is_prefill)


class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
    """
    Helper class for shared methods between CPU model runners.
    """
    _model_input_cls: Type[TModelInputForCPU]
    _builder_cls: Type[ModelInputForCPUBuilder]
    builder: ModelInputForCPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        *args,
        **kwargs,
    ):
        ModelRunnerBase.__init__(self, vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config

        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = False

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        # No attention backend is created for models with zero attention
        # heads unless the model is explicitly attention-free.
        needs_attn_backend = (num_attn_heads != 0
                              or self.model_config.is_attention_free)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        ) if needs_attn_backend else None

        # Lazy initialization.
        self.model: nn.Module  # Set in load_model()
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.sampler = get_sampler()

        if hasattr(self, "_builder_cls"):
            # multi-step model runner does not have `_builder_cls`
            self.builder = self._builder_cls(weakref.proxy(self))

    def load_model(self) -> None:
        """Load the model and wrap it with a LoRA manager if configured."""
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            assert supports_lora(
                self.model
            ), f"{self.model.__class__.__name__} does not support LoRA yet."

            if supports_multimodal(self.model):
                logger.warning("Regarding multimodal models, vLLM currently "
                               "only supports adding LoRA to language model.")

            # Use get_text_config() in case of multimodal models
            text_config = self.model_config.hf_config.get_text_config()

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=text_config.max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

    def get_model(self) -> nn.Module:
        return self.model

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForCPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        """
        self.builder.prepare(finished_requests_ids)
        self.builder.set_seq_group_list(seq_group_metadata_list)

        return self.builder.build()  # type: ignore

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    def _require_lora_manager(self) -> LRUCacheWorkerLoRAManager:
        """Return the LoRA manager; raise RuntimeError if LoRA is disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager

    def remove_all_loras(self):
        self._require_lora_manager().remove_all_adapters()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        self._require_lora_manager().set_active_adapters(
            lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self._require_lora_manager().add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self._require_lora_manager().remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self._require_lora_manager().pin_adapter(lora_id)

    def list_loras(self) -> Set[int]:
        return self._require_lora_manager().list_adapters()


class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
    """CPU model runner that also prepares sampling metadata and samples."""
    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)

        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine,
                                   is_prompt=is_prompt)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        previous_hidden_states: Optional[torch.Tensor] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass, then sample on the driver worker.

        Non-driver workers return an empty list after computing logits.

        Raises:
            ValueError: if `num_steps` > 1 (multi-step is unsupported).
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        model_executable = self.model

        multimodal_kwargs = {}
        if model_input.multi_modal_kwargs is not None:
            multimodal_kwargs = MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs,
                device=self.device,
            )

        execute_model_kwargs = {}
        if previous_hidden_states is not None:
            execute_model_kwargs.update(
                {"previous_hidden_states": previous_hidden_states})

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **execute_model_kwargs,
                **multimodal_kwargs,
            )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # We only need to pass hidden states of most recent token. In a
            # prefill the model returns one row per prompt token, so keep the
            # full tensor in `prefill_hidden_states` and select only the rows
            # the sampler used. Decode rows already map one-to-one.
            if model_input.is_prompt:
                assert model_input.sampling_metadata is not None
                output.prefill_hidden_states = hidden_states
                hidden_states = hidden_states.index_select(
                    0, model_input.sampling_metadata.selected_token_indices)
            output.hidden_states = hidden_states
        return [output]

    def generate_proposals(self, *args, **kwargs):
        """Delegate to the loaded model's `generate_proposals`."""
        return self.model.generate_proposals(*args, **kwargs)