"vllm/vscode:/vscode.git/clone" did not exist on "54e2f83d0a82462e0128e5d852e3d46fbb566a7f"
cpu_model_runner.py 21.6 KB
Newer Older
1
2
import dataclasses
import weakref
3
from collections import defaultdict
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
6
7

import torch
8
from torch import nn
9
10

from vllm.attention import AttentionMetadata, get_attn_backend
11
from vllm.config import VllmConfig
12
13
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
14
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
15
from vllm.model_executor.layers.sampler import SamplerOutput
16
from vllm.model_executor.model_loader import get_model
17
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
18
                             MultiModalInputs, MultiModalPlaceholderMap)
19
20
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
21
from vllm.utils import make_tensor_with_pad
22
from vllm.worker.model_runner_base import (
23
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
24
25
26
27
28
29
30
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
31
32
33
34
35
36

logger = init_logger(__name__)

# Slot id written into the prefill slot mapping for prompt tokens that fall
# before the sliding window (i < max(0, seq_len - sliding_window)), i.e. the
# tokens that are masked out instead of being given a real cache slot.
_PAD_SLOT_ID = -1


37
@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Inputs for one base-model forward pass on CPU.

    Carries the flattened token ids and positions, the attention metadata,
    the batched multi-modal kwargs and per-sequence length bookkeeping.
    Only the token ids, positions, multi-modal kwargs and attention
    metadata are exported by ``as_broadcastable_tensor_dict``.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return the fields to broadcast, with attention metadata merged in
        by ``_add_attn_metadata_broadcastable_dict``."""
        broadcast_dict: Dict[str, Any] = dict(
            input_tokens=self.input_tokens,
            input_positions=self.input_positions,
            multi_modal_kwargs=self.multi_modal_kwargs,
        )
        _add_attn_metadata_broadcastable_dict(broadcast_dict,
                                              self.attn_metadata)
        return broadcast_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForCPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> "ModelInputForCPU":
        """Rebuild an instance from a broadcast dict.

        When ``attn_backend`` is given, the attention-metadata entries are
        first folded back into an ``attn_metadata`` object by
        ``_init_attn_metadata_from_tensor_dict``.
        """
        if attn_backend is None:
            return cls(**tensor_dict)
        restored = _init_attn_metadata_from_tensor_dict(
            attn_backend, tensor_dict)
        return cls(**restored)


73
74
75
76
77
78
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    Used by the ModelRunner.

    Extends ``ModelInputForCPU`` with the sampling metadata consumed by
    ``compute_logits`` / ``sample`` after the forward pass.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the fields to broadcast to other workers.

        ``multi_modal_kwargs`` is included, exactly as in the base class, so
        that workers which rebuild their input from this dict through
        ``from_broadcasted_tensor_dict`` see the same multi-modal inputs
        instead of ``None``.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        """Rebuild an instance from a dict produced by
        ``as_broadcastable_tensor_dict``.

        Sampling metadata is always restored. Attention metadata is restored
        only when ``attn_backend`` is given.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)
101
102


103
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
    """Collects scheduled sequence groups and assembles the CPU model input.

    All groups added before ``build()`` must be either prompts (prefill) or
    decodes; the prefill and decode paths each assert this per group.
    """

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        # finished_requests_ids is accepted for interface compatibility; this
        # builder does not use it.
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Queue one sequence group for the next ``build()`` call."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> ModelInputForCPU:
        """Build the model input for all queued sequence groups.

        Dispatches on the first group's ``is_prompt`` flag to the prefill or
        decode preparation path.
        """
        multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = self.seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs) = self._prepare_prompt(
                 self.seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             attn_metadata) = self._prepare_decode(
                 self.seq_group_metadata_list)
            seq_lens = None

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
            # query_lens is not needed if chunked prefill is not
            # supported. Since CPU worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens=seq_lens,
            query_lens=seq_lens,
        )

    def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata,
                                   seq_data: SequenceData, computed_len: int,
                                   mm_processor_kwargs: Dict[str, Any]):
        """Map the multi-modal items of one prompt to model kwargs.

        Always returns a 3-tuple ``(mm_kwargs, placeholder_maps,
        mrope_positions)``:

        * ``mm_kwargs`` is the multi-modal input mapper's output, or None
          when no multi-modal item intersects the positions
          ``[computed_len, len(token_ids))`` (``placeholder_maps`` is then
          empty).
        * ``mrope_positions`` is only computed when the model uses M-RoPE;
          in that case the position delta is also stored on ``seq_data`` for
          the decode path. Otherwise it is None.
        """
        # NOTE: mm_data only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group, range(computed_len, len(seq_data.get_token_ids())))

        if not mm_data:
            # The caller unpacks three values; a bare `return` would raise.
            return None, {}, None

        mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)

        # special processing for mrope position deltas.
        mrope_positions = None
        if self.runner.model_config.uses_mrope:
            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = self.runner.model_config.hf_config
            token_ids = seq_data.get_token_ids()

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_token_id=hf_config.image_token_id,
                    video_token_id=hf_config.video_token_id,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    spatial_merge_size=hf_config.vision_config.
                    spatial_merge_size,
                    context_len=computed_len,
                )
            # Consumed by _prepare_decode for subsequent decode steps.
            seq_data.mrope_position_delta = mrope_position_delta
        return mm_kwargs, placeholder_maps, mrope_positions

    def _build_positions_tensor(self, input_positions: List[int],
                                input_mrope_positions: List[List[int]],
                                has_mrope_positions: bool) -> torch.Tensor:
        """Convert the collected positions to a ``torch.long`` tensor.

        Uses the three M-RoPE rows when any sequence in the batch produced
        M-RoPE positions, and the flat 1-D positions otherwise.
        """
        positions = (input_mrope_positions
                     if has_mrope_positions else input_positions)
        return torch.tensor(positions, dtype=torch.long, device=self.device)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        """Build prefill inputs for a batch of prompt sequence groups.

        Each group must contain exactly one sequence. Returns the flattened
        token ids, the positions, the prefill attention metadata, the
        per-sequence prompt lengths and the batched multi-modal kwargs.
        """
        assert len(seq_group_metadata_list) > 0
        uses_mrope = self.runner.model_config.uses_mrope
        input_tokens: List[int] = []
        input_positions: List[int] = []
        # One list per M-RoPE row. Only used when at least one sequence in
        # the batch produced M-RoPE positions (has_mrope_positions).
        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
        has_mrope_positions = False

        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []
        multi_modal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            # NOTE(review): input_tokens below takes the whole prompt while
            # positions/slots start at computed_len; this assumes
            # computed_len == 0 (no chunked prefill on CPU) — confirm.
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            mrope_positions = None
            if seq_group_metadata.multi_modal_data:
                mm_kwargs, placeholder_maps, mrope_positions = \
                    self._compute_multi_modal_input(
                        seq_group_metadata, seq_data, computed_len,
                        seq_group_metadata.mm_processor_kwargs)
                # mm_kwargs is None when no multi-modal item intersects the
                # positions being prefilled.
                if mm_kwargs is not None:
                    multi_modal_inputs_list.append(mm_kwargs)
                for modality, placeholder_map in placeholder_maps.items():
                    multi_modal_placeholder_maps[modality].extend(
                        placeholder_map)

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            if mrope_positions is not None:
                has_mrope_positions = True
                for idx in range(3):
                    input_mrope_positions[idx].extend(mrope_positions[idx])
            else:
                token_positions = range(computed_len, seq_len)
                input_positions.extend(token_positions)
                if uses_mrope:
                    # Text-only sequence on an M-RoPE model: mirror the 1-D
                    # positions into every row so the rows stay aligned with
                    # input_tokens if another sequence in this batch uses
                    # M-RoPE. NOTE(review): assumes text tokens get the same
                    # index in all three M-RoPE rows — confirm.
                    for row in input_mrope_positions:
                        row.extend(token_positions)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = self._build_positions_tensor(  # type: ignore
            input_positions, input_mrope_positions, has_mrope_positions)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            multi_modal_placeholder_maps.items()
        }

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            seq_lens=seq_lens,
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        """Build decode inputs: one new token per sequence.

        Returns the last token id of every sequence, its position, and the
        decode attention metadata (per-sequence lengths clipped to the
        sliding window, padded block tables, slot mapping).
        """
        assert len(seq_group_metadata_list) > 0
        uses_mrope = self.runner.model_config.uses_mrope
        input_tokens: List[int] = []
        input_positions: List[int] = []
        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
        has_mrope_positions = False
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                if seq_data.mrope_position_delta is not None:
                    has_mrope_positions = True
                    context_len = seq_data.get_num_computed_tokens()
                    next_pos = MRotaryEmbedding.get_next_input_positions(
                        seq_data.mrope_position_delta,
                        context_len,
                        seq_len,
                    )
                    for idx in range(3):
                        input_mrope_positions[idx].extend(next_pos[idx])
                else:
                    input_positions.append(position)
                    if uses_mrope:
                        # Keep the M-RoPE rows aligned with input_tokens for
                        # text-only sequences (see _prepare_prompt).
                        for row in input_mrope_positions:
                            row.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = self._build_positions_tensor(  # type: ignore
            input_positions, input_mrope_positions, has_mrope_positions)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        block_tables = make_tensor_with_pad(
            block_tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )

404
405
406
407
408
409
410
411

class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
    """Prepares inputs for, and executes, a model on the CPU.

    Per-step inputs are assembled by ``_builder_cls``. ``execute_model``
    runs the forward pass and computes logits on every worker, and samples
    only on the driver worker. Chunked prefill and multi-step execution are
    rejected.
    """
    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        ModelRunnerBase.__init__(self, vllm_config)

        # Currently, CPU worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False

        model_config = self.model_config

        self.is_driver_worker = is_driver_worker
        self.device = self.device_config.device
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = self.cache_config.block_size
        self.attn_backend = get_attn_backend(
            model_config.get_head_size(),
            model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            model_config.is_attention_free,
        )

        # Multi-modal data support.
        registry = MULTIMODAL_REGISTRY
        self.mm_registry = registry
        self.multi_modal_input_mapper = registry.create_input_mapper(
            model_config)
        registry.init_mm_limits_per_prompt(model_config)

        # Lazy initialization: assigned by load_model().
        self.model: nn.Module

    def load_model(self) -> None:
        """Instantiate the model described by ``self.vllm_config``."""
        self.model = get_model(vllm_config=self.vllm_config)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast dict using this runner's
        attention backend."""
        input_cls = ModelInputForCPUWithSamplingMetadata
        return input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Run the input builder over ``seq_group_metadata_list``.

        Produces everything the forward pass needs, but no sampling
        metadata. The builder holds only a weak proxy to this runner.
        """
        runner_ref = weakref.proxy(self)
        builder = self._builder_cls(runner_ref, finished_requests_ids)
        for group in seq_group_metadata_list:
            builder.add_seq_group(group)
        return builder.build()  # type: ignore

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Build the full model input, including sampling metadata, and
        stamp it with ``virtual_engine``."""
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            model_input.seq_lens,
            model_input.query_lens,
            self.device,
            pin_memory=False,
            generators=generators,
        )
        return dataclasses.replace(
            model_input,
            sampling_metadata=sampling_metadata,
            virtual_engine=virtual_engine,
        )

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass, compute logits and sample.

        Returns ``[sampler_output]`` on the driver worker and ``[]`` on all
        other workers.

        Raises:
            ValueError: if ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        # Built as a dict (not direct keywords) so multi-modal entries may
        # override earlier keys exactly as a dict literal would.
        forward_kwargs: Dict[str, Any] = {
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
        }
        forward_kwargs.update(
            MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                       device=self.device))
        forward_kwargs["intermediate_tensors"] = intermediate_tensors

        hidden_states = self.model(**forward_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        sampler_output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [sampler_output]