# cpu_model_runner.py
1
2
import dataclasses
import weakref
3
from collections import defaultdict
4
from dataclasses import dataclass
5
6
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar,
                    Union)
7
8

import torch
9
from torch import nn
10
11

from vllm.attention import AttentionMetadata, get_attn_backend
12
from vllm.config import VllmConfig
13
14
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
15
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
16
from vllm.model_executor.layers.sampler import SamplerOutput
17
from vllm.model_executor.model_loader import get_model
18
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
19
                             MultiModalKwargs, MultiModalPlaceholderMap)
20
21
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
22
from vllm.worker.model_runner_base import (
23
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
24
25
26
27
28
29
30
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
31
32
33

# Module-level logger, keyed by this module's import name.
logger = init_logger(__name__)

# Type variable for APIs that accept or return ModelInputForCPU or any of its
# subclasses (e.g. the alternate constructor on ModelInputForCPU).
TModelInputForCPU = TypeVar("TModelInputForCPU", bound="ModelInputForCPU")

# Placeholder slot id; prompt slot mappings are pre-filled with this value
# before the real slot numbers are written.
_PAD_SLOT_ID = -1


38
@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Immutable bundle of the inputs needed for the base model forward pass
    on CPU.

    Every field defaults to None so an instance can be rebuilt from a
    partial tensor dict via ``from_broadcasted_tensor_dict``.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        Flatten this input into a plain dict.

        The dict holds the token ids, the positions and the multi-modal
        kwargs; the attention metadata is merged in by
        ``_add_attn_metadata_broadcastable_dict``.
        """
        payload: Dict[str, Any] = {}
        payload["input_tokens"] = self.input_tokens
        payload["input_positions"] = self.input_positions
        payload["multi_modal_kwargs"] = self.multi_modal_kwargs
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForCPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> TModelInputForCPU:
        """
        Rebuild an instance from a dict produced by
        ``as_broadcastable_tensor_dict``.

        When ``attn_backend`` is given, the dict is first passed through
        ``_init_attn_metadata_from_tensor_dict``; otherwise its entries are
        used as constructor keyword arguments unchanged.
        """
        if attn_backend is None:
            return cls(**tensor_dict)
        restored = _init_attn_metadata_from_tensor_dict(
            attn_backend, tensor_dict)
        return cls(**restored)


74
75
76
77
78
79
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    Used by the ModelRunner.

    Extends ModelInputForCPU with the metadata consumed by the sampling
    step.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Flatten this input into a plain dict.

        The dict carries the same entries as
        ``ModelInputForCPU.as_broadcastable_tensor_dict`` plus whatever
        ``_add_sampling_metadata_broadcastable_dict`` contributes for the
        sampling metadata.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            # Must mirror the parent class: without this entry the
            # multi-modal inputs are silently dropped from the payload and
            # the instance rebuilt from it has multi_modal_kwargs=None.
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        """
        Rebuild an instance from a dict produced by
        ``as_broadcastable_tensor_dict``.

        The sampling metadata is always restored first; the attention
        metadata is restored only when ``attn_backend`` is provided.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)
102
103


104
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
105

106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    class ModelInputData:
        """
        Mutable scratch space that accumulates the flattened inputs of one
        batch while the builder walks over the scheduled sequences.
        """

        def __init__(self, use_mrope: bool):
            self.use_mrope = use_mrope

            # Flattened token ids of all sequences in the batch.
            self.input_tokens: List[int] = []

            # Exactly one of the two position containers is a list: the flat
            # one without M-RoPE, the three per-axis lists with M-RoPE.
            self.input_positions: Optional[List[int]] = (None if use_mrope
                                                         else [])
            self.input_mrope_positions: Optional[List[List[int]]] = (
                [[], [], []] if use_mrope else None)

            # Per-sequence lengths.
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []

            # KV-cache bookkeeping.
            self.prefill_block_tables: List[List[int]] = []
            self.decode_block_tables: List[List[int]] = []
            self.slot_mapping: List[int] = []

            # Batch-level counters.
            self.max_decode_seq_len: int = 0
            self.num_prefills: int = 0
            self.num_prefill_tokens: int = 0
            self.num_decode_tokens: int = 0

            # Multi-modal inputs and their placeholder maps per modality.
            self.multi_modal_inputs_list: List[MultiModalKwargs] = []
            self.multi_modal_placeholder_maps: Dict[
                str, MultiModalPlaceholderMap] = defaultdict(
                    MultiModalPlaceholderMap)

130
131
132
133
134
135
    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        """
        Bind the builder to ``runner`` and prepare empty accumulators.

        ``finished_requests_ids`` is accepted for interface compatibility
        and is not used by this builder.
        """
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner

        # Prefix caching is treated like chunked prefill here.
        self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                or runner.cache_config.enable_prefix_caching)

        # Shortcuts to runner-owned collaborators.
        self.model_input_cls = runner._model_input_cls
        self.attn_backend = runner.attn_backend
        self.multi_modal_input_mapper = runner.multi_modal_input_mapper

        self.input_data = ModelInputForCPUBuilder.ModelInputData(
            runner.model_config.uses_mrope)

        # Created last: the attention metadata builder receives this
        # (by now fully populated) input builder.
        attn_builder_cls = runner.attn_backend.get_builder_cls()
        self.att_metadata_builder = attn_builder_cls(self)
146

147
148
    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Queue one sequence group to be included by the next build()."""
        pending_groups = self.seq_group_metadata_list
        pending_groups.append(seq_group_metadata)
149

150
151
152
153
    def set_seq_group_list(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        """
        Replace the queued sequence groups with ``seq_group_metadata_list``.

        The list is stored by reference, not copied.
        """
        self.seq_group_metadata_list = seq_group_metadata_list

154
    def build(self) -> ModelInputForCPU:
        """
        Turn the queued sequence groups into a model input object.

        Fills ``self.input_data`` from the queued groups, converts the
        accumulated token ids and positions into ``torch.long`` CPU tensors,
        batches any multi-modal inputs, builds the attention metadata and
        wraps everything in ``self.model_input_cls``.
        """
        self._build_input_data()
        data = self.input_data

        token_tensor = torch.tensor(data.input_tokens,
                                    dtype=torch.long,
                                    device="cpu")

        # With M-RoPE the positions are stored as per-axis lists instead of
        # a single flat list.
        if data.use_mrope:
            raw_positions = data.input_mrope_positions
        else:
            raw_positions = data.input_positions
        position_tensor = torch.tensor(raw_positions,
                                       dtype=torch.long,
                                       device="cpu")

        # For multi-modal models
        mm_inputs = data.multi_modal_inputs_list
        batched_mm_kwargs = (MultiModalKwargs.batch(mm_inputs)
                             if mm_inputs else None)

        # NOTE(review): the two trailing -1 arguments are forwarded as-is;
        # their meaning is defined by the attention metadata builder.
        attn_metadata = self.att_metadata_builder.build(
            data.seq_lens, data.query_lens, -1, -1)

        return self.model_input_cls(
            input_tokens=token_tensor,
            input_positions=position_tensor,
            seq_lens=data.seq_lens,
            query_lens=data.query_lens,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=batched_mm_kwargs,
        )
184

185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
    def _build_input_data(self):
        """
        Populate ``self.input_data`` from every queued sequence.

        Prompt sequences go through the prompt path (plus the multi-modal
        path when the group carries multi-modal data); all other sequences
        go through the decode path.
        """
        batch = self.input_data
        for group in self.seq_group_metadata_list:
            for seq_id, seq_data in group.seq_data.items():
                if not group.is_prompt:
                    self._compute_decode_input_tokens(batch, group, seq_data,
                                                      seq_id)
                    continue

                self._compute_prompt_input_tokens(batch, group, seq_data,
                                                  seq_id)
                if group.multi_modal_data:
                    self._compute_multi_modal_input(group, seq_data)

    def _compute_decode_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute decode input tokens, positions, block table and slot mapping.

        Appends exactly one token (the sequence's last token) to ``data``
        together with its position, KV-cache slot, block table and lengths.

        Args:
            data: Accumulator that is updated in place.
            seq_group_metadata: Group the sequence belongs to; supplies the
                block table for ``seq_id``.
            seq_data: Token data of the sequence being decoded.
            seq_id: Id of the sequence inside ``seq_group_metadata``.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()

        tokens = seq_data.get_last_token_id()
        token_positions = seq_len - 1
        block_number = block_table[token_positions // block_size]
        block_offset = token_positions % block_size
        slot = block_number * block_size + block_offset

        # For paged_attention kernel: with a sliding window only the blocks
        # that overlap the window are passed on, and the length handed to
        # attention is measured from the start of the first kept block.
        # This window-relative length lives in its own variable so that
        # ``seq_len`` keeps the absolute length needed for positions below.
        attn_seq_len = seq_len
        sliding_window = self.runner.sliding_window
        if sliding_window:
            start_idx = max(0, seq_len - sliding_window)
            start_block = start_idx // block_size
            start_idx = start_block * block_size
            attn_seq_len = seq_len - start_idx
            block_table = block_table[start_block:]

        # For MRotaryEmbedding
        if data.input_positions is None:
            # ``context_len`` is an absolute token count, so it must be
            # paired with the absolute sequence length. Pairing it with the
            # window-relative length would request a wrong (possibly empty)
            # position range once the window has started sliding.
            next_pos = MRotaryEmbedding.get_next_input_positions(
                seq_data.mrope_position_delta,
                context_len,
                seq_len,
            )
            for idx in range(3):
                data.input_mrope_positions[idx].extend(  # type: ignore
                    next_pos[idx])
        else:
            data.input_positions.append(token_positions)  # type: ignore

        # Update fields
        data.input_tokens.append(tokens)
        data.max_decode_seq_len = max(data.max_decode_seq_len, attn_seq_len)
        data.num_decode_tokens += 1
        data.slot_mapping.append(slot)
        data.decode_block_tables.append(block_table)
        data.query_lens.append(1)
        data.seq_lens.append(attn_seq_len)

    def _compute_prompt_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute prompt input tokens, positions, block table and slot mapping.

        Appends the tokens of the current prefill chunk of one sequence to
        ``data``. The chunk covers ``[context_len, seq_len)`` where
        ``context_len`` may be moved forward when leading blocks are already
        available from the prefix cache.

        Args:
            data: Accumulator that is updated in place.
            seq_group_metadata: Group the sequence belongs to; supplies the
                chunk size, block table and cached block numbers.
            seq_data: Token data of the sequence being prefilled.
            seq_id: Id of the sequence inside ``seq_group_metadata``.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        context_len = seq_data.get_num_computed_tokens()
        seq_len = min(seq_data.get_len(),
                      context_len + seq_group_metadata.token_chunk_size)

        # For prefix caching. ``computed_block_nums`` may be None when no
        # cached blocks were reported; treat that like an empty list instead
        # of failing in len().
        computed_block_nums = seq_group_metadata.computed_block_nums
        prefix_cache_block_num = (len(computed_block_nums)
                                  if computed_block_nums is not None else 0)
        if prefix_cache_block_num > 0:
            prefix_cache_len = prefix_cache_block_num * block_size
            if prefix_cache_len <= context_len:
                # We already passed the cache hit region,
                # so do normal computation.
                pass
            elif prefix_cache_len < seq_len:
                # Partial hit. Compute the missing part.
                context_len = prefix_cache_len
            else:
                # Full hit. Only compute the last token to avoid
                # erroneous behavior. FIXME: Ideally we should directly
                # mark all tokens as computed in the scheduler and do not
                # schedule this sequence, so this case should not happen.
                context_len = seq_len - 1

        tokens = seq_data.get_token_ids()[context_len:seq_len]
        token_positions = range(context_len, seq_len)

        # For encoder-only models, the block_table is None,
        # and there is no need to initialize the slot_mapping.
        if block_table is not None:
            slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
            for i, pos in enumerate(token_positions):
                block_number = block_table[pos // block_size]
                block_offset = pos % block_size
                slot_mapping[i] = block_number * block_size + block_offset
            data.slot_mapping.extend(slot_mapping)

        # The MROPE positions are prepared in _compute_multi_modal_input
        if data.input_positions is not None:
            data.input_positions.extend(token_positions)

        # Update fields
        num_tokens = len(tokens)
        data.input_tokens.extend(tokens)
        data.num_prefills += 1
        data.num_prefill_tokens += num_tokens
        data.query_lens.append(num_tokens)
        data.prefill_block_tables.append(block_table)
        data.seq_lens.append(seq_len)

    def _compute_multi_modal_input(self,
                                   seq_group_metadata: SequenceGroupMetadata,
                                   seq_data: SequenceData):
        """
        Collect the multi-modal inputs for the most recently added prompt
        chunk.

        The chunk is taken to span from the sequence's number of computed
        tokens up to the last entry of ``self.input_data.seq_lens``. Nothing
        is recorded when no multi-modal item falls inside that range. For
        M-RoPE models the per-axis rotary positions are computed here and
        the resulting position delta is stored on ``seq_data``.
        """
        batch = self.input_data
        runner = self.runner

        chunk_start = seq_data.get_num_computed_tokens()
        chunk_end = batch.seq_lens[-1]

        # NOTE: mm_data only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata, range(chunk_start, chunk_end))
        if not mm_data:
            return

        model_config = runner.model_config

        # Data is used as-is when the registry has a processor for this
        # model; otherwise it goes through the input mapper first.
        if runner.mm_registry.has_processor(model_config):
            mm_kwargs = mm_data
        else:
            mm_kwargs = self.multi_modal_input_mapper(
                mm_data,
                seq_group_metadata.mm_processor_kwargs,
            )

        # special processing for mrope position deltas.
        if model_config.uses_mrope:
            assert not self.chunked_prefill, \
                "MROPE on CPU does not support chunked-prefill."

            image_grid_thw = mm_kwargs.get("image_grid_thw")
            video_grid_thw = mm_kwargs.get("video_grid_thw")
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = model_config.hf_config
            vision_config = hf_config.vision_config
            positions_per_axis, position_delta = \
                MRotaryEmbedding.get_input_positions(
                    seq_data.get_token_ids(),
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_token_id=hf_config.image_token_id,
                    video_token_id=hf_config.video_token_id,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    spatial_merge_size=vision_config.spatial_merge_size,
                    context_len=chunk_start,
                )
            seq_data.mrope_position_delta = position_delta

            # One destination list per rotary axis (three in total).
            for axis, axis_positions in enumerate(
                    batch.input_mrope_positions):  # type: ignore
                axis_positions.extend(positions_per_axis[axis])

        batch.multi_modal_inputs_list.append(mm_kwargs)
        merged_maps = batch.multi_modal_placeholder_maps
        for modality, placeholder_map in placeholder_maps.items():
            merged_maps[modality].extend(placeholder_map)
369

370

371
372
373
374
375
376
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
    """
    Common functionality shared by the CPU model runners.

    Subclasses provide ``_model_input_cls`` (the concrete model input type)
    and ``_builder_cls`` (the builder that produces it).
    """
    _model_input_cls: Type[TModelInputForCPU]
    _builder_cls: Type[ModelInputForCPUBuilder]

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """
        Set up attention backend and multi-modal support; the model itself
        is loaded later by ``load_model``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        ModelRunnerBase.__init__(self, vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config

        self.is_driver_worker = is_driver_worker
        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        head_size = model_config.get_head_size()
        self.attn_backend = get_attn_backend(
            head_size,
            model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            model_config.is_attention_free,
        )

        # Multi-modal data support
        registry = MULTIMODAL_REGISTRY
        self.mm_registry = registry
        self.multi_modal_input_mapper = registry.create_input_mapper(
            model_config)
        registry.init_mm_limits_per_prompt(model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set by load_model()

    def load_model(self) -> None:
        """Load the model described by ``self.vllm_config``."""
        self.model = get_model(vllm_config=self.vllm_config)

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForCPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        """
        # The builder only gets a weak proxy of the runner.
        runner_proxy = weakref.proxy(self)
        input_builder = self._builder_cls(runner_proxy, finished_requests_ids)
        input_builder.set_seq_group_list(seq_group_metadata_list)

        return input_builder.build()  # type: ignore

432
433
434
435
436
437
438
439
440
441
442
443
444
445
446

class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
    """
    CPU model runner that prepares inputs including sampling metadata,
    runs the forward pass and samples the next token.
    """
    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Rebuild a model input from a broadcasted tensor dict, using this
        runner's attention backend to restore the attention metadata.

        """
        input_cls = ModelInputForCPUWithSamplingMetadata
        return input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        """
        base_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            base_input.seq_lens,
            base_input.query_lens,
            self.device,
            pin_memory=False,
            generators=generators,
        )

        # The input dataclass is frozen, so a modified copy is returned.
        return dataclasses.replace(base_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and, on the driver worker, sample from it.

        Returns a one-element list with the sampler output on the driver
        worker and an empty list on every other worker. Raises ValueError
        when ``num_steps`` is greater than 1.

        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model = self.model
        sampling_metadata = model_input.sampling_metadata

        mm_inputs = model_input.multi_modal_kwargs
        if mm_inputs is None:
            multimodal_kwargs = {}
        else:
            multimodal_kwargs = MultiModalKwargs.as_kwargs(
                mm_inputs, device=self.device)

        hidden_states = model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **multimodal_kwargs,
        )

        # Compute the logits.
        logits = model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        sampled = model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return [sampled]