import dataclasses
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar,
                    Union)

import torch
from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

# Type variable bound to ModelInputForCPU; used to type the classmethod
# constructor and the generic CPU runner base in terms of the concrete
# model-input class.
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
# Placeholder slot id used to pre-fill a prompt's slot mapping before the
# real slot indices are written into it.
_PAD_SLOT_ID = -1


39
@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Base class contains metadata needed for the base model forward pass on CPU

    All fields default to None so that an instance can be rebuilt from a
    broadcast tensor dict that carries only a subset of them.
    """
    # Flattened token ids for the whole batch.
    input_tokens: Optional[torch.Tensor] = None
    # Positions matching ``input_tokens``.
    input_positions: Optional[torch.Tensor] = None
    # Token type ids; None when no sequence in the batch supplied any.
    token_type_ids: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    # Per-sequence lengths; not part of the broadcastable dict below.
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return the fields that are sent to other workers as a dict.

        The dict holds ``input_tokens``, ``input_positions``,
        ``token_type_ids`` and ``multi_modal_kwargs``, and is then handed to
        ``_add_attn_metadata_broadcastable_dict`` together with
        ``attn_metadata``. ``seq_lens``, ``query_lens`` and
        ``virtual_engine`` are not included.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)

        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForCPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> TModelInputForCPU:
        """Build an instance from a dict received via broadcast.

        Args:
            tensor_dict: Mapping whose keys become constructor keyword
                arguments.
            attn_backend: When given, ``tensor_dict`` is first passed through
                ``_init_attn_metadata_from_tensor_dict`` with this backend.

        Returns:
            A new instance of ``cls`` created as ``cls(**tensor_dict)``.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


77
78
79
80
81
82
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    Used by the ModelRunner.

    Extends ModelInputForCPU with the metadata needed by the sampling step.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the broadcastable dict, including sampling metadata.

        Starts from the parent class' dict (input tensors plus attention
        metadata) and then hands it to
        ``_add_sampling_metadata_broadcastable_dict`` together with
        ``sampling_metadata``.
        """
        # Reuse the parent's dict instead of repeating the key list, so the
        # two classes cannot drift apart when a field is added.
        tensor_dict: Dict[str, Any] = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        """Build an instance from a dict received via broadcast.

        Args:
            tensor_dict: Mapping whose keys become constructor keyword
                arguments after the helper(s) below have processed it.
            attn_backend: When given, the dict is additionally passed through
                ``_init_attn_metadata_from_tensor_dict`` with this backend.

        Returns:
            A new instance created as ``cls(**tensor_dict)``.
        """
        # Sampling metadata is restored first, then attention metadata.
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)
107
108


109
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
    """
    Accumulates sequence groups and converts them into a CPU model input.

    Sequence groups are registered with ``add_seq_group`` or
    ``set_seq_group_list``. ``build`` then walks them, fills a
    ``ModelInputData`` accumulator and converts it into an instance of the
    runner's ``_model_input_cls``. The accumulator is created once in
    ``__init__`` and only ever appended to.
    """

    class ModelInputData:
        """Mutable per-batch accumulator filled while walking the sequences."""

        def __init__(self, use_mrope: bool):
            """
            Args:
                use_mrope: When true, positions are collected as three
                    parallel lists in ``input_mrope_positions`` and
                    ``input_positions`` is None; otherwise
                    ``input_positions`` is a flat list and
                    ``input_mrope_positions`` is None.
            """
            self.use_mrope = use_mrope
            self.input_tokens: List[int] = []
            self.input_positions: Optional[
                List[int]] = [] if not self.use_mrope else None
            self.token_type_ids: Optional[List[int]] = []
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []
            self.prefill_block_tables: List[List[int]] = []
            self.decode_block_tables: List[List[int]] = []
            self.max_decode_seq_len: int = 0
            self.num_prefills: int = 0
            self.num_prefill_tokens: int = 0
            self.num_decode_tokens: int = 0
            self.slot_mapping: List[int] = []
            self.multi_modal_inputs_list: List[MultiModalKwargs] = []
            self.multi_modal_placeholder_maps: Dict[
                str, MultiModalPlaceholderMap] = defaultdict(
                    MultiModalPlaceholderMap)
            self.input_mrope_positions: Optional[List[List[int]]] = [
                [] for _ in range(3)
            ] if self.use_mrope else None

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        """
        Args:
            runner: The model runner this builder reads configuration from
                (scheduler/cache/model config, attention backend,
                multi-modal input mapper, model input class).
            finished_requests_ids: Accepted for interface compatibility; it
                is not used by this builder.
        """
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner

        # Either setting makes the builder treat prefill as chunked.
        self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                or runner.cache_config.enable_prefix_caching)
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
        self.input_data = ModelInputForCPUBuilder.ModelInputData(
            self.runner.model_config.uses_mrope)
        # The backend's metadata builder is constructed with this input
        # builder as its only argument.
        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
            self)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Append one sequence group to the list processed by ``build``."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def set_seq_group_list(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Replace the list of sequence groups processed by ``build``.

        The given list object is stored as-is (not copied).
        """
        self.seq_group_metadata_list = seq_group_metadata_list

    def build(self) -> ModelInputForCPU:
        """Convert the registered sequence groups into a model input.

        Returns:
            An instance of the runner's ``_model_input_cls`` holding the
            token/position tensors (``torch.long`` on CPU), optional token
            type ids and multi-modal kwargs, the sequence/query lengths and
            the attention metadata produced by the backend's builder.
        """
        self._build_input_data()

        input_data = self.input_data
        input_tokens = torch.tensor(input_data.input_tokens,
                                    dtype=torch.long,
                                    device="cpu")
        input_positions = torch.tensor(
            input_data.input_positions
            if not input_data.use_mrope else input_data.input_mrope_positions,
            dtype=torch.long,
            device="cpu")
        # An empty (or None) token type list yields None instead of an
        # empty tensor.
        token_type_ids = (torch.tensor(input_data.token_type_ids,
                                       dtype=torch.long,
                                       device="cpu")
                          if input_data.token_type_ids else None)

        # For multi-modal models
        multi_modal_kwargs = None
        if input_data.multi_modal_inputs_list:
            multi_modal_kwargs = MultiModalKwargs.batch(
                input_data.multi_modal_inputs_list)

        # NOTE(review): the two trailing -1 arguments are presumably unused
        # placeholders for this backend's builder; confirm against its
        # signature.
        attn_metadata = self.att_metadata_builder.build(
            input_data.seq_lens, input_data.query_lens, -1, -1)

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            token_type_ids=token_type_ids,
            seq_lens=input_data.seq_lens,
            query_lens=input_data.query_lens,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def _build_input_data(self):
        """Fill ``self.input_data`` from every sequence of every group.

        Prompt sequences go through ``_compute_prompt_input_tokens`` (and
        ``_compute_multi_modal_input`` when the group carries multi-modal
        data); all other sequences go through
        ``_compute_decode_input_tokens``.
        """
        for seq_group_metadata in self.seq_group_metadata_list:
            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                if seq_group_metadata.is_prompt:
                    self._compute_prompt_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)
                    if seq_group_metadata.multi_modal_data:
                        self._compute_multi_modal_input(
                            seq_group_metadata, seq_data)
                else:
                    self._compute_decode_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)

    def _compute_decode_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute decode input tokens, positions, block table and slot mapping.

        Appends exactly one token (the sequence's last token id) and its
        bookkeeping to ``data``.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()

        tokens = seq_data.get_last_token_id()
        token_positions = seq_len - 1
        # The slot is resolved against the full block table, before any
        # sliding-window trimming below.
        block_number = block_table[token_positions // block_size]
        block_offset = token_positions % block_size
        slot = block_number * block_size + block_offset

        # For paged_attention kernel
        if self.runner.sliding_window:
            # Round the window start down to a block boundary, then drop the
            # blocks before it and shorten seq_len accordingly.
            start_idx = max(0, seq_len - self.runner.sliding_window)
            start_block = start_idx // block_size
            start_idx = start_block * block_size
            seq_len = seq_len - start_idx
            block_table = block_table[start_block:]

        # For MRotaryEmbedding
        if data.input_positions is None:
            # NOTE(review): seq_len here has already been shortened by the
            # sliding-window branch above when that branch ran; confirm this
            # is the intended value for get_next_input_positions.
            next_pos = MRotaryEmbedding.get_next_input_positions(
                seq_data.mrope_position_delta,
                context_len,
                seq_len,
            )
            for idx in range(3):
                data.input_mrope_positions[idx].extend(  # type: ignore
                    next_pos[idx])
        else:
            data.input_positions.append(token_positions)  # type: ignore

        # Update fields
        data.input_tokens.append(tokens)
        data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
        data.num_decode_tokens += 1
        data.slot_mapping.append(slot)
        data.decode_block_tables.append(block_table)
        data.query_lens.append(1)
        data.seq_lens.append(seq_len)

    def _compute_prompt_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute prompt input tokens, positions, block table and slot mapping.

        Appends the tokens in ``[context_len, seq_len)`` of the sequence and
        their bookkeeping to ``data``, where ``seq_len`` is capped by the
        group's ``token_chunk_size`` and ``context_len`` may be advanced by a
        prefix-cache hit.
        """
        token_chunk_size = seq_group_metadata.token_chunk_size
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()
        seq_len = min(seq_len, context_len + token_chunk_size)

        # For prefix caching
        # NOTE(review): computed_block_nums is presumably Optional on
        # SequenceGroupMetadata; a None value is treated as "no cached
        # blocks" instead of raising TypeError. Confirm against its
        # definition.
        prefix_cache_block_num = len(seq_group_metadata.computed_block_nums
                                     or ())
        if prefix_cache_block_num > 0:
            prefix_cache_len = (prefix_cache_block_num *
                                self.runner.block_size)
            if prefix_cache_len <= context_len:
                # We already passed the cache hit region,
                # so do normal computation.
                pass
            elif context_len < prefix_cache_len < seq_len:
                # Partial hit. Compute the missing part.
                context_len = prefix_cache_len
                token_chunk_size = seq_len - context_len
            elif seq_len <= prefix_cache_len:
                # Full hit. Only compute the last token to avoid
                # erroneous behavior. FIXME: Ideally we should directly
                # mark all tokens as computed in the scheduler and do not
                # schedule this sequence, so this case should not happen.
                context_len = seq_len - 1
                token_chunk_size = 1

        tokens = seq_data.get_token_ids()
        tokens = tokens[context_len:seq_len]
        token_positions = range(context_len, seq_len)
        token_types = seq_group_metadata.token_type_ids

        # For encoder-only models, the block_table is None,
        # and there is no need to initialize the slot_mapping.
        if block_table is not None:
            slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
            for i, pos in enumerate(token_positions):
                block_number = block_table[pos // block_size]
                block_offset = pos % block_size
                slot = block_number * block_size + block_offset
                slot_mapping[i] = slot
            data.slot_mapping.extend(slot_mapping)

        # The MROPE positions are prepared in _compute_multi_modal_input
        if data.input_positions is not None:
            data.input_positions.extend(token_positions)

        if data.token_type_ids is not None:
            data.token_type_ids.extend(token_types if token_types else [])

        # Update fields
        data.input_tokens.extend(tokens)
        data.num_prefills += 1
        data.num_prefill_tokens += len(tokens)
        data.query_lens.append(len(tokens))
        data.prefill_block_tables.append(block_table)
        data.seq_lens.append(seq_len)

    def _compute_multi_modal_input(self,
                                   seq_group_metadata: SequenceGroupMetadata,
                                   seq_data: SequenceData):
        """Collect multi-modal kwargs (and MROPE positions) for one prompt.

        Must be called right after ``_compute_prompt_input_tokens`` for the
        same sequence, because the sequence length is read from the last
        entry of ``self.input_data.seq_lens``. Returns early without touching
        ``self.input_data`` when no multi-modal item intersects the current
        prefill range.
        """
        computed_len = seq_data.get_num_computed_tokens()
        seq_len = self.input_data.seq_lens[-1]

        # NOTE: mm_data only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata, range(computed_len, seq_len))

        if not mm_data:
            return

        # When the registry has a processor for this model, mm_data is used
        # directly; otherwise it goes through the input mapper first.
        if self.runner.mm_registry.has_processor(self.runner.model_config):
            mm_kwargs = mm_data
        else:
            mm_kwargs = self.multi_modal_input_mapper(
                mm_data,
                seq_group_metadata.mm_processor_kwargs,
            )

        # special processing for mrope position deltas.
        if self.runner.model_config.uses_mrope:
            assert not self.chunked_prefill, \
                "MROPE on CPU does not support chunked-prefill."

            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = self.runner.model_config.hf_config
            token_ids = seq_data.get_token_ids()

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_token_id=hf_config.image_token_id,
                    video_token_id=hf_config.video_token_id,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    spatial_merge_size=hf_config.vision_config.
                    spatial_merge_size,
                    context_len=computed_len,
                )
            # Stored on the sequence so the decode path can derive the next
            # positions from it.
            seq_data.mrope_position_delta = mrope_position_delta

            for i in range(3):
                self.input_data.input_mrope_positions[  # type: ignore
                    i].extend(mrope_positions[i])

        self.input_data.multi_modal_inputs_list.append(mm_kwargs)
        for modality, placeholder_map in placeholder_maps.items():
            self.input_data.multi_modal_placeholder_maps[modality].extend(
                placeholder_map)
384

385

386
387
388
389
390
391
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
    """
    Helper class for shared methods between CPU model runners.

    Subclasses set ``_model_input_cls`` (the model-input class produced by
    the builder) and ``_builder_cls`` (the builder used to produce it).
    """
    _model_input_cls: Type[TModelInputForCPU]
    _builder_cls: Type[ModelInputForCPUBuilder]

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """
        Args:
            vllm_config: Configuration forwarded to ``ModelRunnerBase``.
            kv_cache_dtype: KV cache dtype passed on to the attention
                backend selection.
            is_driver_worker: Whether this runner belongs to the driver
                worker.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        ModelRunnerBase.__init__(self, vllm_config)
        # NOTE(review): model_config / cache_config / device_config are
        # presumably populated by ModelRunnerBase.__init__ above; confirm.
        model_config = self.model_config
        cache_config = self.cache_config

        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.multi_modal_input_mapper = self.mm_registry \
            .create_input_mapper(self.model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model

    def load_model(self) -> None:
        """Load the model described by ``self.vllm_config`` into
        ``self.model``."""
        self.model = get_model(vllm_config=self.vllm_config)

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForCPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        A fresh builder is created for every call.
        """
        # The builder only gets a weak proxy of the runner, so it does not
        # hold a strong reference back to it.
        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
        builder.set_seq_group_list(seq_group_metadata_list)

        return builder.build()  # type: ignore

447
448
449
450
451
452
453
454
455
456
457
458
459
460
461

class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
    """CPU model runner whose model input also carries sampling metadata."""
    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast tensor dict, using this
        runner's attention backend to restore the attention metadata."""
        return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for this step.
            virtual_engine: Stored on the returned model input.
            finished_requests_ids: Forwarded to the input builder and to
                ``get_generators``.

        Returns:
            A copy of the built model input with ``sampling_metadata`` and
            ``virtual_engine`` filled in.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)

        # The model input is a frozen dataclass, so a modified copy is made
        # instead of assigning the fields in place.
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and, on the driver worker, sample from it.

        Args:
            model_input: Input produced by ``prepare_model_input``.
            kv_caches: KV cache tensors passed through to the model.
            intermediate_tensors: Passed through to the model.
            num_steps: Must be 1.

        Returns:
            An empty list on non-driver workers; otherwise a one-element
            list holding the sampler output.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model
        multimodal_kwargs = {}
        if model_input.multi_modal_kwargs is not None:
            multimodal_kwargs = MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs, device=self.device)

        with set_forward_context(model_input.attn_metadata, self.vllm_config):
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **multimodal_kwargs,
            )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]