# cpu_model_runner.py
from dataclasses import dataclass
2
3
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
                    Type, Union)
4
5

import torch
6
from torch import nn
7
8

from vllm.attention import AttentionMetadata, get_attn_backend
9
10
11
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
12
13
14
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
15
16
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
17
18
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
19
from vllm.utils import make_tensor_with_pad
20
21
22
23
24
25
26
27
28
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
29
30
31
32
33
34

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Sentinel written into the slot mapping for prompt tokens that fall before
# the sliding-window start index (see `CPUModelRunner._prepare_prompt`).
_PAD_SLOT_ID = -1


35
36
37
38
39
40
41
42
43
@dataclass(frozen=True)
class CPUModelInput(ModelRunnerInputBase):
    """Immutable bundle of the per-step inputs consumed by CPUModelRunner.

    Every field defaults to ``None`` so an instance can be rebuilt from a
    tensor dict that carries only a subset of the keys.
    """
    # Token ids fed to the model.
    input_tokens: Optional[torch.Tensor] = None
    # Position id of each entry in ``input_tokens``.
    input_positions: Optional[torch.Tensor] = None
    # Attention metadata produced by the attention backend.
    attn_metadata: Optional["AttentionMetadata"] = None
    # Metadata used when computing logits and sampling.
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Extra keyword arguments for multi-modal models, if any.
    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return this input as a plain dict.

        The token, position and multi-modal fields are stored under their
        own names; the attention and sampling metadata are added by the
        shared ``_add_*_broadcastable_dict`` helpers.
        """
        payload = dict(
            input_tokens=self.input_tokens,
            input_positions=self.input_positions,
            multi_modal_kwargs=self.multi_modal_kwargs,
        )
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(payload,
                                                  self.sampling_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
            cls: Type["CPUModelInput"],
            tensor_dict: Dict[str, Any],
            attn_backend: Optional["AttentionBackend"] = None
    ) -> "CPUModelInput":
        """Build a ``CPUModelInput`` from a broadcasted tensor dict.

        Args:
            tensor_dict: Dict whose entries become the constructor keyword
                arguments after the metadata helpers have processed it.
            attn_backend: Backend handed to the attention-metadata helper.
                When ``None`` that helper is skipped and ``tensor_dict`` is
                only passed through the sampling-metadata helper.
        """
        field_values = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            field_values = _init_attn_metadata_from_tensor_dict(
                attn_backend, field_values)
        return cls(**field_values)


class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
    """Prepares model inputs and runs one model step on the configured device.

    A batch handed to this runner is treated as either all prompts or all
    decodes; chunked prefill and multi-step execution are rejected.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """Store the configs and set up the attention backend.

        Extra positional and keyword arguments are accepted and ignored.
        The model itself is not loaded here; call ``load_model`` for that.

        Raises:
            AssertionError: If chunked prefill is enabled in
                ``scheduler_config``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # The CPU worker does not support chunked prefill yet.
        assert self.scheduler_config.chunked_prefill_enabled is False

        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        # Values derived from the configs above.
        self.device = self.device_config.device
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        num_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.attn_backend = get_attn_backend(
            num_heads,
            head_size,
            num_kv_heads,
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Mapper that turns raw multi-modal data into model keyword inputs.
        self.multi_modal_input_mapper = (
            MULTIMODAL_REGISTRY.create_input_mapper(self.model_config))

        # Declared here, assigned lazily by load_model().
        self.model: nn.Module

    def load_model(self) -> None:
        """Load the model described by the stored configs into ``self.model``."""
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            vision_language_config=self.vision_language_config,
            lora_config=self.lora_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            cache_config=self.cache_config,
        )

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Mapping[str, BatchedTensors]]:
        """Build the model inputs for a batch of prompt sequence groups.

        Each group must be a prompt holding exactly one sequence.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata,
            seq_lens, multi_modal_kwargs)``. The first two are flat
            ``torch.long`` tensors over all prompts in the batch and
            ``seq_lens`` holds the token count of each prompt.
        """
        assert len(seq_group_metadata_list) > 0
        token_ids: List[int] = []
        position_ids: List[int] = []
        slots: List[int] = []
        seq_lens: List[int] = []
        mm_inputs: List[MultiModalInputs] = []

        for group in seq_group_metadata_list:
            assert group.is_prompt
            seq_ids = list(group.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = group.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)
            token_ids.extend(prompt_tokens)
            # NOTE: this assumes the first token of the prompt is also the
            # first token of the sequence.
            position_ids.extend(range(computed_len, seq_len))

            mm_data = group.multi_modal_data
            if mm_data:
                mm_inputs.append(self.multi_modal_input_mapper(mm_data))

            # Slot mapping. Prompt positions in [0, window_start) are
            # written as _PAD_SLOT_ID, with
            # window_start = max(0, seq_len - sliding_window).
            # Example: prompt length 10, sliding window 8, block size 4 and
            # a block table of [0, 1, 0] give
            # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            block_table = group.block_tables[seq_id]
            if self.sliding_window is None:
                window_start = 0
            else:
                window_start = max(0, seq_len - self.sliding_window)

            for pos in range(computed_len, seq_len):
                if pos >= window_start:
                    block_idx, offset = divmod(pos, self.block_size)
                    block_number = block_table[block_idx]  # type: ignore
                    slots.append(block_number * self.block_size + offset)
                else:
                    slots.append(_PAD_SLOT_ID)

        num_prompt_tokens = len(token_ids)

        input_tokens = torch.tensor(token_ids,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(position_ids,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slots,
                                    dtype=torch.long,
                                    device=self.device)

        # Prefill-only metadata: there are no decode tokens, and an empty
        # tensor is passed for the block tables.
        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            seq_lens=seq_lens,
            seq_lens_tensor=None,
            max_decode_seq_len=None,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
        )

        multi_modal_kwargs = MultiModalInputs.batch(mm_inputs,
                                                    device=self.device)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        """Build the model inputs for a batch of decode sequence groups.

        Every sequence contributes exactly one token (its last one). Each
        group must be a non-prompt group with ``token_chunk_size == 1``.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata)``.
        """
        assert len(seq_group_metadata_list) > 0
        token_ids: List[int] = []
        position_ids: List[int] = []
        slots: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for group in seq_group_metadata_list:
            assert not group.is_prompt
            assert group.token_chunk_size == 1

            for seq_id in list(group.seq_data.keys()):
                seq_data = group.seq_data[seq_id]
                token_ids.append(seq_data.get_last_token_id())

                full_len = seq_data.get_len()
                position = full_len - 1
                position_ids.append(position)

                # The recorded sequence length is capped at the sliding
                # window, when one is configured.
                if self.sliding_window is None:
                    seq_lens.append(full_len)
                else:
                    seq_lens.append(min(full_len, self.sliding_window))

                block_table = group.block_tables[seq_id]
                block_idx, offset = divmod(position, self.block_size)
                slots.append(block_table[block_idx] * self.block_size +
                             offset)

                # With a sliding window, keep only the trailing blocks.
                if self.sliding_window is not None:
                    window_blocks = self.sliding_window // self.block_size
                    block_table = block_table[-window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(token_ids,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(position_ids,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slots,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        # Pad every block table with zeros up to the longest one.
        longest_table = max(map(len, block_tables))
        block_tables = make_tensor_with_pad(
            block_tables,
            max_len=longest_table,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return input_tokens, input_positions, attn_metadata

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> CPUModelInput:
        """Rebuild a ``CPUModelInput`` from ``tensor_dict`` using this
        runner's attention backend."""
        return CPUModelInput.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)

    def prepare_model_input(
            self,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            virtual_engine: int = 0,
            finished_requests_ids: Optional[List[str]] = None
    ) -> CPUModelInput:
        """Turn scheduled sequence groups into a ``CPUModelInput``.

        The batch is assumed to be all prompts or all decodes; only the
        first group's ``is_prompt`` flag is inspected. ``virtual_engine``
        and ``finished_requests_ids`` are accepted but not used here.
        """
        if not seq_group_metadata_list[0].is_prompt:
            input_tokens, input_positions, attn_metadata = (
                self._prepare_decode(seq_group_metadata_list))
            seq_lens = []
            multi_modal_kwargs = None
        else:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)

        # query_lens is only needed with chunked prefill, which the CPU
        # worker does not support, so seq_lens is passed for both.
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            seq_lens,
            self.device,
            pin_memory=False)

        return CPUModelInput(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            sampling_metadata=sampling_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: CPUModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and, on the driver worker, sample a token.

        Args:
            model_input: Inputs produced by ``prepare_model_input``.
            kv_caches: KV cache tensors forwarded to the model.
            intermediate_tensors: Accepted but not used here.
            num_steps: Must not exceed 1.

        Returns:
            A one-element list with the sampler output on the driver
            worker, or an empty list on every other worker.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        # Built as a dict so that multi-modal entries override same-named
        # keys instead of raising a duplicate-keyword error.
        forward_kwargs = dict(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
        )
        forward_kwargs.update(model_input.multi_modal_kwargs or {})

        hidden_states = self.model(**forward_kwargs)

        # Logits are computed on every worker; only the driver samples.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)
        if not self.is_driver_worker:
            return []

        sampled = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [sampled]