import dataclasses
import weakref
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import torch
from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, PromptAdapterConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalInputs)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    # Only needed for annotations.
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

# Sentinel slot index placed in a prompt's slot mapping for token positions
# that are masked out (those before the sliding-window start).
_PAD_SLOT_ID = -1


@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Base class containing the metadata needed for the base model forward
    pass on CPU.
    """
    # Flattened token ids of the whole batch.
    input_tokens: Optional[torch.Tensor] = None
    # Position id for each entry of ``input_tokens``.
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    # Per-sequence lengths; used locally (e.g. for sampling metadata) and
    # not included in the broadcast dict.
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the fields that are broadcast to other workers.

        Contains ``input_tokens``, ``input_positions`` and
        ``multi_modal_kwargs``, plus whatever
        ``_add_attn_metadata_broadcastable_dict`` adds for
        ``attn_metadata``. ``seq_lens``, ``query_lens`` and
        ``virtual_engine`` are not included.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForCPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> "ModelInputForCPU":
        """Rebuild an instance from a dict produced by
        ``as_broadcastable_tensor_dict``.

        When ``attn_backend`` is given, the flattened attention fields are
        turned back into attention metadata first. Any remaining key must
        be a field of ``cls``, because the dict is passed as keyword
        arguments to the constructor.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    Used by the ModelRunner.

    Extends ``ModelInputForCPU`` with the sampling metadata attached by
    ``CPUModelRunner.prepare_model_input``.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the fields that are broadcast to other workers.

        Starts from the parent's dict (``input_tokens``, ``input_positions``,
        ``multi_modal_kwargs`` and the attention-metadata fields) and adds
        the broadcastable sampling-metadata fields.
        """
        # Build on the parent serialization rather than re-listing its keys,
        # so ``multi_modal_kwargs`` (needed by multi-modal models on
        # non-driver workers) is never dropped from the broadcast.
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        """Rebuild an instance from a dict produced by
        ``as_broadcastable_tensor_dict``.

        Sampling fields are converted back into ``sampling_metadata``
        first. The attention fields are then converted back into attention
        metadata when ``attn_backend`` is given.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
    """Collects scheduled sequence groups and builds the CPU model input.

    Groups are queued with ``add_seq_group``. ``build`` then turns the
    whole batch into token/position tensors and attention metadata. The
    first queued group decides whether the batch is handled as prompts or
    as decodes.
    """

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        # ``finished_requests_ids`` is accepted but not used by this builder.
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner
        # Copy the runner settings used while building inputs.
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Queue one sequence group for the next ``build`` call."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> ModelInputForCPU:
        """Build a ``self.model_input_cls`` instance from the queued groups.

        Requires at least one queued group. Prompt batches also produce
        ``multi_modal_kwargs`` and ``seq_lens``. Decode batches leave
        ``multi_modal_kwargs`` as None and ``seq_lens`` empty.
        """
        multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = self.seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs) = self._prepare_prompt(
                 self.seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             attn_metadata) = self._prepare_decode(
                 self.seq_group_metadata_list)
            seq_lens = []

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
            # query_lens is not needed if chunked prefill is not
            # supported. Since CPU worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens=seq_lens,
            query_lens=seq_lens,
        )

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        """Flatten a batch of prompt sequence groups into model inputs.

        Each group must be a prompt holding exactly one sequence. Returns
        ``(input_tokens, input_positions, attn_metadata, seq_lens,
        multi_modal_kwargs)``, where ``seq_lens[i]`` is the token count of
        the i-th prompt.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            # NOTE(review): all prompt tokens are appended here, but the
            # positions and slot mapping below start at ``computed_len``.
            # They only line up when ``computed_len == 0``. This appears to
            # rely on chunked prefill being disabled for the CPU runner.
            # TODO confirm that nothing else can leave computed tokens here.
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

            if (mm_data := seq_group_metadata.multi_modal_data):
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            seq_lens=seq_lens,
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        """Flatten a batch of decode sequence groups into model inputs.

        Every sequence contributes one token (its last token id) at position
        ``len - 1``. With a sliding window, each sequence length is capped
        at the window size. Each block table is also trimmed to its last
        ``sliding_window // block_size`` entries. Returns
        ``(input_tokens, input_positions, attn_metadata)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                # The slot is looked up in the full block table, before any
                # sliding-window trimming below.
                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        # Block tables can differ in length; pad them to a rectangular tensor.
        block_tables = make_tensor_with_pad(
            block_tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )


class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
    """Prepares batched inputs for a model on CPU and executes it.

    Inputs are assembled by ``_builder_cls`` into ``_model_input_cls``
    instances. ``execute_model`` runs one forward pass and, on the driver
    worker only, samples the next tokens.
    """
    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        # Extra ``*args`` / ``**kwargs`` are accepted but not used here.
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # Currently, CPU worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.prompt_adapter_config = prompt_adapter_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.multi_modal_input_mapper = self.mm_registry \
            .create_input_mapper(self.model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set by load_model()

        if self.model_config.is_encoder_decoder_model:
            raise NotImplementedError(
                STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])

    def load_model(self) -> None:
        """Load the model weights and store the module on ``self.model``."""
        self.model = get_model(model_config=self.model_config,
                               load_config=self.load_config,
                               device_config=self.device_config,
                               lora_config=self.lora_config,
                               parallel_config=self.parallel_config,
                               scheduler_config=self.scheduler_config,
                               cache_config=self.cache_config)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Rebuild a model input from a dict broadcast by the driver worker.

        The driver's inputs are ``self._model_input_cls`` instances (built
        by ``_builder_cls`` and updated with ``dataclasses.replace``). Their
        broadcast dict can carry sampling-metadata fields. Decoding with the
        same class consumes those fields and restores ``sampling_metadata``,
        which ``execute_model`` reads. The plain ``ModelInputForCPU``
        accepts neither.
        """
        return self._model_input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        """
        # A weak proxy keeps the builder from holding a strong reference
        # back to this runner.
        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)

        return builder.build()  # type: ignore

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)

        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and, on the driver worker, sample.

        Returns ``[]`` on non-driver workers (after computing logits), and a
        one-element list with the sampler output on the driver.

        Raises:
            ValueError: if ``num_steps > 1``. Multi-step execution is not
                supported on CPU.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            model_input.attn_metadata,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
            "intermediate_tensors":
            intermediate_tensors,
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]