# cpu_model_runner.py
1
2
import dataclasses
import weakref
3
from collections import defaultdict
4
from dataclasses import dataclass
5
6
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar,
                    Union)
7
8

import torch
9
from torch import nn
10
11

from vllm.attention import AttentionMetadata, get_attn_backend
12
from vllm.config import VllmConfig
youkaichao's avatar
youkaichao committed
13
from vllm.forward_context import set_forward_context
14
15
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
16
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
17
from vllm.model_executor.layers.sampler import SamplerOutput
18
from vllm.model_executor.model_loader import get_model
19
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
20
                             MultiModalKwargs, MultiModalPlaceholderMap)
21
22
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
23
from vllm.worker.model_runner_base import (
24
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
25
26
27
28
29
30
31
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
32
33
34

logger = init_logger(__name__)

# Generic type for the concrete model-input dataclass a runner/builder
# produces; bounded by ModelInputForCPU.
TModelInputForCPU = TypeVar("TModelInputForCPU", bound="ModelInputForCPU")

# Initial fill value for a prompt chunk's slot_mapping entries; each entry is
# overwritten with its computed cache slot right afterwards.
_PAD_SLOT_ID: int = -1


39
@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Immutable inputs for the base model forward pass on CPU (everything
    except sampling metadata).
    """
    # Token ids of the whole batch, flattened.
    input_tokens: Optional[torch.Tensor] = None
    # Token positions: one row, or three rows when the model uses M-RoPE.
    input_positions: Optional[torch.Tensor] = None
    # Per-token type ids; None when no sequence group supplied any.
    token_type_ids: Optional[torch.Tensor] = None
    # Attention metadata produced by the backend's metadata builder.
    attn_metadata: Optional["AttentionMetadata"] = None
    # Batched multi-modal inputs, or None for text-only batches.
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    # Filled in by CPUModelRunner.prepare_model_input via dataclasses.replace.
    virtual_engine: Optional[int] = None
    # Per-sequence total / scheduled lengths; not included in broadcasts.
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Collect the fields that are broadcast to other workers.

        ``virtual_engine``, ``seq_lens`` and ``query_lens`` are not included.
        The attention metadata is added by
        ``_add_attn_metadata_broadcastable_dict``.
        """
        payload = dict(
            input_tokens=self.input_tokens,
            input_positions=self.input_positions,
            token_type_ids=self.token_type_ids,
            multi_modal_kwargs=self.multi_modal_kwargs,
        )
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForCPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> TModelInputForCPU:
        """Rebuild an instance from a broadcast dict.

        When ``attn_backend`` is given, the dict is first passed through
        ``_init_attn_metadata_from_tensor_dict`` with that backend.
        """
        if attn_backend is None:
            fields = tensor_dict
        else:
            fields = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**fields)


77
78
79
80
81
82
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    Model input used by CPUModelRunner: the base forward-pass inputs plus the
    metadata needed by the sampling step.
    """
    # Built in CPUModelRunner.prepare_model_input.
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Taken from the first sequence group of the batch; not broadcast.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Base-class broadcast dict extended with the sampling metadata."""
        payload = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(payload,
                                                  self.sampling_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        """Rebuild an instance from a broadcast dict.

        The sampling entries are processed first. The base class then applies
        the attention backend (if any) and constructs ``cls``.
        """
        restored = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        return super().from_broadcasted_tensor_dict(
            restored, attn_backend=attn_backend)
108
109


110
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
    """
    Collects scheduled sequence groups (``add_seq_group`` /
    ``set_seq_group_list``) and turns them into a ``ModelInputForCPU``
    (``build``).
    """

    class ModelInputData:
        """Mutable per-batch accumulator filled while building.

        NOTE(review): block tables and the prefill/decode counters are not
        read in this file; presumably the attention-metadata builder (which
        is handed the enclosing builder) consumes them, so keep these
        attribute names stable.
        """

        def __init__(self, use_mrope: bool):
            self.use_mrope = use_mrope
            self.input_tokens: List[int] = []
            # Flat positions for regular RoPE. None under M-RoPE, where
            # input_mrope_positions holds three position rows instead.
            self.input_positions: Optional[
                List[int]] = [] if not self.use_mrope else None
            self.token_type_ids: Optional[List[int]] = []
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []
            self.prefill_block_tables: List[List[int]] = []
            self.decode_block_tables: List[List[int]] = []
            self.max_decode_seq_len: int = 0
            self.num_prefills: int = 0
            self.num_prefill_tokens: int = 0
            self.num_decode_tokens: int = 0
            self.slot_mapping: List[int] = []
            self.multi_modal_inputs_list: List[MultiModalKwargs] = []
            self.multi_modal_placeholder_maps: Dict[
                str, MultiModalPlaceholderMap] = defaultdict(
                    MultiModalPlaceholderMap)
            self.input_mrope_positions: Optional[List[List[int]]] = [
                [] for _ in range(3)
            ] if self.use_mrope else None

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner

        # Multi-modal M-RoPE preparation refuses to run when either chunked
        # prefill or prefix caching is enabled.
        self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                or runner.cache_config.enable_prefix_caching)
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
        self.input_data = ModelInputForCPUBuilder.ModelInputData(
            self.runner.model_config.uses_mrope)
        # The attention-metadata builder is constructed with this builder.
        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
            self)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Queue one sequence group for the next ``build``."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def set_seq_group_list(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Replace the queued sequence groups with ``seq_group_metadata_list``."""
        self.seq_group_metadata_list = seq_group_metadata_list

    def build(self) -> ModelInputForCPU:
        """Convert the queued sequence groups into a model-input instance.

        Returns:
            An instance of the runner's ``_model_input_cls`` holding CPU
            ``long`` tensors for tokens/positions (and token types, if any),
            the batched multi-modal kwargs and the attention metadata.
        """
        self._build_input_data()

        input_data = self.input_data
        input_tokens = torch.tensor(input_data.input_tokens,
                                    dtype=torch.long,
                                    device="cpu")
        input_positions = torch.tensor(
            input_data.input_positions
            if not input_data.use_mrope else input_data.input_mrope_positions,
            dtype=torch.long,
            device="cpu")
        token_type_ids = (torch.tensor(input_data.token_type_ids,
                                       dtype=torch.long,
                                       device="cpu")
                          if input_data.token_type_ids else None)

        # For multi-modal models
        multi_modal_kwargs = None
        if input_data.multi_modal_inputs_list:
            multi_modal_kwargs = MultiModalKwargs.batch(
                input_data.multi_modal_inputs_list)

        attn_metadata = self.att_metadata_builder.build(
            input_data.seq_lens, input_data.query_lens, -1, -1)

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            token_type_ids=token_type_ids,
            seq_lens=input_data.seq_lens,
            query_lens=input_data.query_lens,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def _build_input_data(self):
        """Fill ``self.input_data`` from every queued sequence."""
        data = self.input_data
        for seq_group_metadata in self.seq_group_metadata_list:
            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                if not seq_group_metadata.is_prompt:
                    self._compute_decode_input_tokens(data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)
                    continue

                self._compute_prompt_input_tokens(data, seq_group_metadata,
                                                  seq_data, seq_id)
                has_mrope_positions = False
                if seq_group_metadata.multi_modal_data:
                    has_mrope_positions = self._compute_multi_modal_input(
                        seq_group_metadata, seq_data)
                # An M-RoPE model still needs positions for prompts (or
                # prompt chunks) without multi-modal items; otherwise the
                # position rows fall out of step with input_tokens.
                if data.use_mrope and not has_mrope_positions:
                    self._compute_text_mrope_positions(data, seq_data)

    def _compute_decode_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute decode input tokens, positions, block table and slot mapping.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()

        token = seq_data.get_last_token_id()
        position = seq_len - 1
        block_number = block_table[position // block_size]
        slot = block_number * block_size + position % block_size

        # Positions must use the full sequence length, so they are computed
        # before the sliding-window trim below shortens seq_len.
        if data.input_positions is None:
            # For MRotaryEmbedding. A missing delta means no multi-modal
            # prefill ever set one, i.e. the sequence is text-only.
            delta = seq_data.mrope_position_delta
            next_pos = MRotaryEmbedding.get_next_input_positions(
                0 if delta is None else delta,
                context_len,
                seq_len,
            )
            for idx in range(3):
                data.input_mrope_positions[idx].extend(  # type: ignore
                    next_pos[idx])
        else:
            data.input_positions.append(position)

        # For paged_attention kernel: only the blocks inside the sliding
        # window are passed, and seq_len is measured from the window start.
        if self.runner.sliding_window:
            start_block = max(0, seq_len -
                              self.runner.sliding_window) // block_size
            seq_len -= start_block * block_size
            block_table = block_table[start_block:]

        # Update fields
        data.input_tokens.append(token)
        data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
        data.num_decode_tokens += 1
        data.slot_mapping.append(slot)
        data.decode_block_tables.append(block_table)
        data.query_lens.append(1)
        data.seq_lens.append(seq_len)

    def _compute_prompt_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute prompt input tokens, positions, block table and slot mapping.

        Only the window ``[context_len, seq_len)`` scheduled for this step is
        added. ``context_len`` may be moved past prefix-cached blocks. M-RoPE
        positions are not produced here (see ``_build_input_data``).
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        context_len = seq_data.get_num_computed_tokens()
        seq_len = min(seq_data.get_len(),
                      context_len + seq_group_metadata.token_chunk_size)

        # For prefix caching (computed_block_nums may be None).
        prefix_cache_block_num = len(seq_group_metadata.computed_block_nums
                                     or ())
        if prefix_cache_block_num > 0:
            prefix_cache_len = prefix_cache_block_num * block_size
            if prefix_cache_len <= context_len:
                # We already passed the cache hit region,
                # so do normal computation.
                pass
            elif context_len < prefix_cache_len < seq_len:
                # Partial hit. Compute the missing part.
                context_len = prefix_cache_len
            elif seq_len <= prefix_cache_len:
                # Full hit. Only compute the last token to avoid
                # erroneous behavior. FIXME: Ideally we should directly
                # mark all tokens as computed in the scheduler and do not
                # schedule this sequence, so this case should not happen.
                context_len = seq_len - 1

        tokens = seq_data.get_token_ids()[context_len:seq_len]
        token_positions = range(context_len, seq_len)
        token_types = seq_group_metadata.token_type_ids

        # For encoder-only models, the block_table is None,
        # and there is no need to initialize the slot_mapping.
        if block_table is not None:
            slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
            for i, pos in enumerate(token_positions):
                block_number = block_table[pos // block_size]
                block_offset = pos % block_size
                slot_mapping[i] = block_number * block_size + block_offset
            data.slot_mapping.extend(slot_mapping)

        # M-RoPE positions are produced in _build_input_data instead.
        if data.input_positions is not None:
            data.input_positions.extend(token_positions)

        if data.token_type_ids is not None:
            # Keep token types aligned with the tokens fed in this step.
            data.token_type_ids.extend(
                token_types[context_len:seq_len] if token_types else [])

        # Update fields
        data.input_tokens.extend(tokens)
        data.num_prefills += 1
        data.num_prefill_tokens += len(tokens)
        data.query_lens.append(len(tokens))
        data.prefill_block_tables.append(block_table)
        data.seq_lens.append(seq_len)

    @staticmethod
    def _compute_text_mrope_positions(data: ModelInputData,
                                      seq_data: SequenceData) -> None:
        """Append M-RoPE positions for the latest prompt chunk (no vision).

        Must run right after ``_compute_prompt_input_tokens`` for the same
        sequence, because the chunk bounds are read from the last
        ``seq_lens``/``query_lens`` entries. A sequence without a delta gets
        ``mrope_position_delta = 0``. NOTE(review): this assumes a zero delta
        reproduces plain sequential text positions on all three axes. The
        stored delta is what later decode steps use.
        """
        seq_len = data.seq_lens[-1]
        context_len = seq_len - data.query_lens[-1]
        if seq_data.mrope_position_delta is None:
            seq_data.mrope_position_delta = 0
        next_pos = MRotaryEmbedding.get_next_input_positions(
            seq_data.mrope_position_delta,
            context_len,
            seq_len,
        )
        for idx in range(3):
            data.input_mrope_positions[idx].extend(  # type: ignore
                next_pos[idx])

    def _compute_multi_modal_input(self,
                                   seq_group_metadata: SequenceGroupMetadata,
                                   seq_data: SequenceData) -> bool:
        """
        Add the multi-modal inputs overlapping the current prefill window.

        Must run after ``_compute_prompt_input_tokens`` for the same
        sequence, because the window end is read from
        ``self.input_data.seq_lens[-1]``.

        Returns:
            True if M-RoPE positions for this prompt were appended to
            ``self.input_data.input_mrope_positions``, otherwise False.
        """
        computed_len = seq_data.get_num_computed_tokens()
        seq_len = self.input_data.seq_lens[-1]

        # NOTE: mm_data only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata, range(computed_len, seq_len))

        if not mm_data:
            return False

        if self.runner.mm_registry.has_processor(self.runner.model_config):
            mm_kwargs = mm_data
        else:
            mm_kwargs = self.multi_modal_input_mapper(
                mm_data,
                seq_group_metadata.mm_processor_kwargs,
            )

        added_mrope_positions = False
        # special processing for mrope position deltas.
        if self.runner.model_config.uses_mrope:
            assert not self.chunked_prefill, \
                "MROPE on CPU does not support chunked-prefill."

            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = self.runner.model_config.hf_config
            token_ids = seq_data.get_token_ids()

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_token_id=hf_config.image_token_id,
                    video_token_id=hf_config.video_token_id,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    spatial_merge_size=hf_config.vision_config.
                    spatial_merge_size,
                    context_len=computed_len,
                )
            # Stored so that decode steps can continue the positions.
            seq_data.mrope_position_delta = mrope_position_delta

            for i in range(3):
                self.input_data.input_mrope_positions[  # type: ignore
                    i].extend(mrope_positions[i])
            added_mrope_positions = True

        self.input_data.multi_modal_inputs_list.append(mm_kwargs)
        for modality, placeholder_map in placeholder_maps.items():
            self.input_data.multi_modal_placeholder_maps[modality].extend(
                placeholder_map)
        return added_mrope_positions
385

386

387
388
389
390
391
392
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
    """
    Shared plumbing for the CPU model runners: configuration, attention
    backend selection, multi-modal setup, model loading and building the
    forward-pass inputs.
    """
    # Model-input dataclass the builder instantiates.
    _model_input_cls: Type[TModelInputForCPU]
    # Builder class created afresh for every _prepare_model_input_tensors.
    _builder_cls: Type[ModelInputForCPUBuilder]

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        *args,
        **kwargs,
    ):
        ModelRunnerBase.__init__(self, vllm_config)
        mcfg = self.model_config

        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = False

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = mcfg.get_sliding_window()
        self.block_size = self.cache_config.block_size

        # An attention backend is created when the model has attention heads
        # or is flagged as attention-free; otherwise it stays None.
        wants_backend = (
            mcfg.get_num_attention_heads(self.parallel_config) != 0
            or mcfg.is_attention_free)
        self.attn_backend = get_attn_backend(
            mcfg.get_head_size(),
            mcfg.dtype,
            self.kv_cache_dtype,
            self.block_size,
            mcfg.is_attention_free,
        ) if wants_backend else None

        # Multi-modal data support
        registry = MULTIMODAL_REGISTRY
        self.mm_registry = registry
        self.multi_modal_input_mapper = registry.create_input_mapper(mcfg)
        registry.init_mm_limits_per_prompt(mcfg)

        # Lazy initialization: assigned by load_model().
        self.model: nn.Module

    def load_model(self) -> None:
        """Build the model described by ``self.vllm_config``."""
        self.model = get_model(vllm_config=self.vllm_config)

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForCPU:
        """Build the base forward-pass inputs for the given sequence groups.

        Only metadata needed by the model forward pass is produced; anything
        for later steps (e.g. sampling) is left to the caller.
        """
        # The builder gets a weak proxy so it holds no strong reference back
        # to this runner.
        runner_ref = weakref.proxy(self)
        builder = self._builder_cls(runner_ref, finished_requests_ids)
        builder.set_seq_group_list(seq_group_metadata_list)
        return builder.build()  # type: ignore

    # sampler property will be used by spec_decode_worker
    @property
    def sampler(self):
        """The loaded model's ``sampler`` attribute."""
        return self.model.sampler

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        return self.model_config.get_vocab_size()

464
465
466
467
468
469
470
471
472
473
474
475
476
477
478

class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
    """CPU model runner that also prepares sampling metadata and samples."""
    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast dict using this runner's
        attention backend."""
        return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        ``is_prompt`` is taken from the first sequence group, or None for an
        empty list.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)

        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine,
                                   is_prompt=is_prompt)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        previous_hidden_states: Optional[torch.Tensor] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass, compute logits and sample.

        Returns:
            ``[]`` on non-driver workers (after computing logits). Otherwise
            a single-element list with the sampler output. When
            ``return_hidden_states`` is set, the hidden states are attached
            to that output.

        Raises:
            ValueError: if ``num_steps > 1`` (multi-step is unsupported).
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        multimodal_kwargs: Dict[str, Any] = {}
        if model_input.multi_modal_kwargs is not None:
            multimodal_kwargs = MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs, device=self.device)

        execute_model_kwargs: Dict[str, Any] = {}
        if previous_hidden_states is not None:
            execute_model_kwargs[
                "previous_hidden_states"] = previous_hidden_states

        with set_forward_context(model_input.attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **execute_model_kwargs,
                **multimodal_kwargs,
            )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # We only need to pass hidden states of the sampled (most
            # recent) tokens. A prefill batch produces one row per prompt
            # token, so select the sampler's rows. The full prefill states
            # are handed over separately.
            if model_input.is_prompt:
                assert model_input.sampling_metadata is not None
                indices = model_input.sampling_metadata.selected_token_indices
                output.prefill_hidden_states = hidden_states
                output.hidden_states = hidden_states.index_select(0, indices)
            else:
                output.hidden_states = hidden_states
        return [output]

    def generate_proposals(self, *args, **kwargs):
        """Forward to the loaded model's ``generate_proposals``."""
        return self.model.generate_proposals(*args, **kwargs)