# cpu_model_runner.py
import dataclasses
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar,
                    Union)

import torch
from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

# Type variable bound to ModelInputForCPU so runner helpers can be typed
# against the concrete model-input subclass they produce.
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")

# Filler value for slot-mapping entries: prompt slot mappings are pre-filled
# with it before each position's real slot is written.
_PAD_SLOT_ID = -1


@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Base class holding the metadata needed for the base model forward pass
    on CPU.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    token_type_ids: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return the fields that are broadcast to other workers.

        Only the token / position / token-type tensors, the multi-modal
        kwargs and whatever ``_add_attn_metadata_broadcastable_dict`` adds
        for ``attn_metadata`` are included; ``virtual_engine``, ``seq_lens``
        and ``query_lens`` are not broadcast.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForCPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> TModelInputForCPU:
        """Rebuild a model input from a dict made by
        ``as_broadcastable_tensor_dict``.

        Args:
            tensor_dict: The broadcast dict; its keys must match this
                dataclass's fields once attention metadata is restored.
            attn_backend: When given, used to turn the attention entries of
                ``tensor_dict`` back into an ``attn_metadata`` object.

        Returns:
            A new instance of ``cls``.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    Used by the ModelRunner: the base CPU model input plus the sampling
    metadata and whether the batch is a prompt batch.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the base broadcast dict extended with sampling metadata.

        ``is_prompt`` is not part of the broadcast.
        """
        # Reuse the parent's serialization so the two classes cannot drift
        # apart, then append the sampling-metadata entries.
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        """Rebuild an instance from a dict made by
        ``as_broadcastable_tensor_dict``.

        Sampling metadata is restored first; attention metadata is restored
        only when ``attn_backend`` is given.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
    """Builds a CPU model input from sequence-group metadata.

    Sequence groups are registered with ``add_seq_group`` or
    ``set_seq_group_list``. ``build`` then walks them once, accumulating
    tokens, positions, slot mappings, block tables and multi-modal inputs in
    a ``ModelInputData`` object. It converts the result into CPU tensors plus
    attention metadata.
    """

    class ModelInputData:
        """Mutable per-batch accumulator filled while walking the batch."""

        def __init__(self, use_mrope: bool):
            self.use_mrope = use_mrope
            self.input_tokens: List[int] = []
            self.input_positions: List[int] = []
            self.token_type_ids: Optional[List[int]] = []
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []
            self.prefill_block_tables: List[List[int]] = []
            self.decode_block_tables: List[List[int]] = []
            self.max_decode_seq_len: int = 0
            self.num_prefills: int = 0
            self.num_prefill_tokens: int = 0
            self.num_decode_tokens: int = 0
            self.slot_mapping: List[int] = []
            self.multi_modal_inputs_list: List[MultiModalKwargs] = []
            self.multi_modal_placeholder_maps: Dict[
                str, MultiModalPlaceholderMap] = defaultdict(
                    MultiModalPlaceholderMap)
            # One position list per M-RoPE component (three in total); these
            # stay empty unless M-RoPE positions are computed for a sequence.
            self.input_mrope_positions: List[List[int]] = [[]
                                                           for _ in range(3)]

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        """
        Args:
            runner: The owning model runner; configs, the attention backend
                and the multi-modal input mapper are read from it.
            finished_requests_ids: Accepted for interface compatibility; not
                used by this builder.
        """
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner

        # Set when either chunked prefill or prefix caching is enabled; read
        # by the M-RoPE check below and presumably by the attention metadata
        # builder (TODO confirm).
        self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                or runner.cache_config.enable_prefix_caching)
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
        self.input_data = ModelInputForCPUBuilder.ModelInputData(
            self.runner.model_config.uses_mrope)
        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
            self)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Append one sequence group to the batch being built."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def set_seq_group_list(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Replace the batch being built with ``seq_group_metadata_list``."""
        self.seq_group_metadata_list = seq_group_metadata_list

    def build(self) -> ModelInputForCPU:
        """Convert the registered sequence groups into a model input.

        Returns:
            An instance of ``self.model_input_cls`` with CPU ``torch.long``
            tensors for tokens, positions and (when any were collected)
            token type ids. It also carries the attention metadata,
            per-sequence lengths and batched multi-modal kwargs.
        """
        self._build_input_data()

        input_data = self.input_data
        input_tokens = torch.tensor(input_data.input_tokens,
                                    dtype=torch.long,
                                    device="cpu")
        # When any M-RoPE positions were collected, the 3-row M-RoPE
        # positions are used instead of the flat positions.
        positions = (input_data.input_mrope_positions
                     if any(input_data.input_mrope_positions) else
                     input_data.input_positions)
        input_positions = torch.tensor(positions,
                                       dtype=torch.long,
                                       device="cpu")
        token_type_ids = (torch.tensor(input_data.token_type_ids,
                                       dtype=torch.long,
                                       device="cpu")
                          if input_data.token_type_ids else None)

        # For multi-modal models
        multi_modal_kwargs = None
        if input_data.multi_modal_inputs_list:
            multi_modal_kwargs = MultiModalKwargs.batch(
                input_data.multi_modal_inputs_list)

        # NOTE(review): the two trailing -1 arguments look like "no CUDA-graph
        # padding / batch size" placeholders — confirm against the builder.
        attn_metadata = self.att_metadata_builder.build(
            input_data.seq_lens, input_data.query_lens, -1, -1)

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            token_type_ids=token_type_ids,
            seq_lens=input_data.seq_lens,
            query_lens=input_data.query_lens,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def _build_input_data(self):
        """Populate ``self.input_data`` from every registered sequence."""
        for seq_group_metadata in self.seq_group_metadata_list:
            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                if seq_group_metadata.is_prompt:
                    self._compute_prompt_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)
                    if seq_group_metadata.multi_modal_data:
                        # Relies on the prompt step above having appended
                        # this sequence's length to input_data.seq_lens.
                        self._compute_multi_modal_input(
                            seq_group_metadata, seq_data)
                else:
                    self._compute_decode_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)

    def _compute_decode_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute decode input tokens, positions, block table and slot mapping.

        Exactly one token (the sequence's last token) is added per call.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()

        tokens = seq_data.get_last_token_id()
        token_positions = seq_len - 1
        block_number = block_table[token_positions // block_size]
        block_offset = token_positions % block_size
        slot = block_number * block_size + block_offset

        # For MRotaryEmbedding. This must use the full, untrimmed seq_len.
        # The sliding-window adjustment below shrinks seq_len but not
        # context_len, which would make the (context_len, seq_len) range
        # wrong for the new token.
        if seq_data.mrope_position_delta is not None:
            next_pos = MRotaryEmbedding.get_next_input_positions(
                seq_data.mrope_position_delta,
                context_len,
                seq_len,
            )
            for idx in range(3):
                data.input_mrope_positions[idx].extend(  # type: ignore
                    next_pos[idx])
        else:
            data.input_positions.append(token_positions)  # type: ignore

        # For paged_attention kernel: keep only the blocks that overlap the
        # sliding window and express seq_len relative to the first kept block.
        if self.runner.sliding_window:
            start_idx = max(0, seq_len - self.runner.sliding_window)
            start_block = start_idx // block_size
            start_idx = start_block * block_size
            seq_len = seq_len - start_idx
            block_table = block_table[start_block:]

        # Update fields
        data.input_tokens.append(tokens)
        data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
        data.num_decode_tokens += 1
        data.slot_mapping.append(slot)
        data.decode_block_tables.append(block_table)
        data.query_lens.append(1)
        data.seq_lens.append(seq_len)

    def _compute_prompt_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute prompt input tokens, positions, block table and slot mapping.

        Only tokens in ``[context_len, seq_len)`` are added. ``context_len``
        starts at the number of already computed tokens and may be advanced
        past prefix-cache hits. ``seq_len`` is capped by the scheduled chunk
        size.
        """
        token_chunk_size = seq_group_metadata.token_chunk_size
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()
        seq_len = min(seq_len, context_len + token_chunk_size)

        # For prefix caching
        prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
        if prefix_cache_block_num > 0:
            prefix_cache_len = prefix_cache_block_num * block_size
            if prefix_cache_len <= context_len:
                # We already passed the cache hit region,
                # so do normal computation.
                pass
            elif context_len < prefix_cache_len < seq_len:
                # Partial hit. Compute the missing part.
                context_len = prefix_cache_len
            elif seq_len <= prefix_cache_len:
                # Full hit. Only compute the last token to avoid
                # erroneous behavior. FIXME: Ideally we should directly
                # mark all tokens as computed in the scheduler and do not
                # schedule this sequence, so this case should not happen.
                context_len = seq_len - 1

        tokens = seq_data.get_token_ids()[context_len:seq_len]
        token_positions = range(context_len, seq_len)
        token_types = seq_group_metadata.token_type_ids

        # For encoder-only models, the block_table is None,
        # and there is no need to initialize the slot_mapping.
        if block_table is not None:
            slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
            for i, pos in enumerate(token_positions):
                block_number = block_table[pos // block_size]
                block_offset = pos % block_size
                slot_mapping[i] = block_number * block_size + block_offset
            data.slot_mapping.extend(slot_mapping)

        # The MROPE positions are prepared in _compute_multi_modal_input
        data.input_positions.extend(token_positions)

        if data.token_type_ids is not None:
            # Slice to the scheduled span so the type ids stay aligned with
            # ``tokens`` when only part of the prompt is computed this step.
            data.token_type_ids.extend(
                token_types[context_len:seq_len] if token_types else [])

        # Update fields
        data.input_tokens.extend(tokens)
        data.num_prefills += 1
        data.num_prefill_tokens += len(tokens)
        data.query_lens.append(len(tokens))
        data.prefill_block_tables.append(block_table)
        data.seq_lens.append(seq_len)

    def _compute_multi_modal_input(self,
                                   seq_group_metadata: SequenceGroupMetadata,
                                   seq_data: SequenceData):
        """Collect multi-modal inputs for the prompt span just added.

        Must run right after ``_compute_prompt_input_tokens`` for the same
        sequence, because it reads that sequence's ``seq_len`` from
        ``self.input_data.seq_lens[-1]``. For M-RoPE models it also computes
        the M-RoPE positions and stores the position delta on ``seq_data``.
        """
        computed_len = seq_data.get_num_computed_tokens()
        seq_len = self.input_data.seq_lens[-1]

        # NOTE: mm_data only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata, range(computed_len, seq_len))

        if not mm_data:
            return

        if self.runner.mm_registry.has_processor(self.runner.model_config):
            mm_kwargs = mm_data
        else:
            mm_kwargs = self.multi_modal_input_mapper(
                mm_data,
                seq_group_metadata.mm_processor_kwargs,
            )

        # special processing for mrope position deltas.
        if self.runner.model_config.uses_mrope:
            assert not self.chunked_prefill, \
                "MROPE on CPU does not support chunked-prefill."

            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = self.runner.model_config.hf_config
            token_ids = seq_data.get_token_ids()

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_token_id=hf_config.image_token_id,
                    video_token_id=hf_config.video_token_id,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    spatial_merge_size=hf_config.vision_config.
                    spatial_merge_size,
                    context_len=computed_len,
                )
            # Stored so later decode steps can derive their positions.
            seq_data.mrope_position_delta = mrope_position_delta

            for i in range(3):
                self.input_data.input_mrope_positions[  # type: ignore
                    i].extend(mrope_positions[i])

        self.input_data.multi_modal_inputs_list.append(mm_kwargs)
        for modality, placeholder_map in placeholder_maps.items():
            self.input_data.multi_modal_placeholder_maps[modality].extend(
                placeholder_map)


class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
    """
    Helper class for shared methods between CPU model runners.

    Subclasses set ``_model_input_cls`` (the model-input dataclass produced)
    and ``_builder_cls`` (the input builder used to produce it).
    """
    _model_input_cls: Type[TModelInputForCPU]
    _builder_cls: Type[ModelInputForCPUBuilder]

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        *args,
        **kwargs,
    ):
        """
        Args:
            vllm_config: Engine configuration; unpacked by
                ``ModelRunnerBase.__init__`` into ``self.model_config``,
                ``self.cache_config`` and the other config attributes.
            kv_cache_dtype: KV-cache dtype name passed to the attention
                backend selector.
            is_driver_worker: Whether this worker performs sampling.
            return_hidden_states: Whether ``execute_model`` attaches hidden
                states to its sampler output.
        """
        ModelRunnerBase.__init__(self, vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config

        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = False

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        # An attention backend is created unless the model reports zero
        # attention heads and is not flagged as attention-free.
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        needs_attn_backend = (num_attn_heads != 0
                              or self.model_config.is_attention_free)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        ) if needs_attn_backend else None

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.multi_modal_input_mapper = self.mm_registry \
            .create_input_mapper(self.model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model

    def load_model(self) -> None:
        """Instantiate the model described by ``self.vllm_config``."""
        self.model = get_model(vllm_config=self.vllm_config)

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForCPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        The builder receives a ``weakref.proxy`` of this runner, so it does
        not hold a strong reference back to the runner.
        """
        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
        builder.set_seq_group_list(seq_group_metadata_list)

        return builder.build()  # type: ignore

    # sampler property will be used by spec_decode_worker
    @property
    def sampler(self):
        """The loaded model's ``sampler`` attribute."""
        return self.model.sampler

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        return self.model_config.get_vocab_size()


class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
    """CPU model runner that prepares sampling metadata and samples tokens."""
    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Rebuild a model input received from the driver worker."""
        return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        ``is_prompt`` is taken from the first sequence group (``None`` for an
        empty batch).
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)

        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine,
                                   is_prompt=is_prompt)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        previous_hidden_states: Optional[torch.Tensor] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and, on the driver worker, sample.

        Returns:
            ``[]`` on non-driver workers; otherwise a one-element list with
            the sampler output.

        Raises:
            ValueError: If ``num_steps > 1`` (multi-step is unsupported).
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model

        multimodal_kwargs = {}
        if model_input.multi_modal_kwargs is not None:
            multimodal_kwargs = MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs, device=self.device)

        execute_model_kwargs = {}
        if previous_hidden_states is not None:
            execute_model_kwargs.update(
                {"previous_hidden_states": previous_hidden_states})

        with set_forward_context(model_input.attn_metadata, self.vllm_config):
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **execute_model_kwargs,
                **multimodal_kwargs,
            )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            if model_input.is_prompt:
                # A prefill produces one row per prompt token. Keep all rows
                # in prefill_hidden_states, but reduce hidden_states to the
                # rows that were sampled from.
                assert model_input.sampling_metadata is not None
                output.prefill_hidden_states = hidden_states
                hidden_states = hidden_states.index_select(
                    0, model_input.sampling_metadata.selected_token_indices)
            output.hidden_states = hidden_states
        return [output]

    def generate_proposals(self, *args, **kwargs):
        """Delegate proposal generation to the loaded model."""
        return self.model.generate_proposals(*args, **kwargs)