# openvino_model_runner.py
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Tuple

import openvino as ov
import torch
from torch import nn

from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata

logger = init_logger(__name__)


class ModelInput(NamedTuple):
    """Per-batch inputs produced by the model runner's input preparation."""

    # Token ids for the batch.
    input_tokens: torch.Tensor
    # Position index of each entry in input_tokens.
    input_positions: torch.Tensor
    # Attention metadata; None for an empty batch (see empty()).
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    # One length per sequence in the batch.
    seq_lens: List[int]
    # Number of tokens fed to the model for each sequence in this step.
    query_lens: List[int]
    # Batched multi-modal keyword inputs.
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device):
        """Return an instance that represents a batch with no sequences.

        Args:
            device: Device on which the zero-length tensors are allocated.

        Returns:
            An instance of ``cls`` with zero-element tensors, no attention
            metadata, and empty lists / dict for the remaining fields.
        """
        # Build through ``cls`` so subclasses get an instance of their own
        # type rather than a plain ModelInput.
        # NOTE: torch.empty uses the default dtype here, whereas populated
        # batches are built elsewhere with an explicit dtype.
        return cls(input_tokens=torch.empty(0, device=device),
                   input_positions=torch.empty(0, device=device),
                   attn_metadata=None,
                   seq_lens=[],
                   query_lens=[],
                   multi_modal_kwargs={})


class OpenVINOModelRunner:

    def __init__(
        self,
        ov_core: ov.Core,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """Store the configuration objects and pick the attention backend.

        The model itself is not loaded here; ``self.model`` is only assigned
        by ``load_model()``.

        Args:
            ov_core: OpenVINO core object; kept so ``load_model()`` can hand
                it to the model loader.
            model_config: Model configuration. Queried here for the sliding
                window, head size, dtype and the attention-free flag, and
                used to create the multi-modal input mapper.
            parallel_config: Stored as ``self.parallel_config``.
            scheduler_config: Stored as ``self.scheduler_config``.
            device_config: Stored; its ``device`` becomes ``self.device``.
            cache_config: Stored; its ``block_size`` becomes
                ``self.block_size``.
            load_config: Stored as ``self.load_config``.
            lora_config: Stored as ``self.lora_config``.
            multimodal_config: Stored as ``self.multimodal_config``.
            kv_cache_dtype: KV-cache dtype name passed to the attention
                backend selection and, later, to the model loader.
            is_driver_worker: Stored as ``self.is_driver_worker``.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        self.ov_core = ov_core
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set by load_model()

    def load_model(self) -> None:
        """Load the model through ``get_model`` and store it as ``self.model``.

        The loader receives the model and device configuration, the KV-cache
        dtype and the OpenVINO core object that were stored by ``__init__``.
        """
        self.model = get_model(model_config=self.model_config,
                               device_config=self.device_config,
                               kv_cache_dtype=self.kv_cache_dtype,
                               ov_core=self.ov_core)

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for this step.

        Returns:
            A ``ModelInput`` holding the flattened token ids and positions,
            the attention metadata built by the attention backend, the
            per-sequence ``seq_lens`` / ``query_lens`` and the batched
            multi-modal kwargs. An empty input list yields
            ``ModelInput.empty(self.device)``.

        Raises:
            RuntimeError: If chunked prefill is enabled and a sequence group
                carries a non-empty ``computed_block_nums`` (prefix caching).
        """
        # Nothing scheduled: skip all bookkeeping.
        if not seq_group_metadata_list:
            return ModelInput.empty(self.device)

        input_tokens: List[int] = []
        input_positions: List[int] = []

        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []
        multi_modal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        # Prefix sums; each starts at 0 and gains one entry per sequence.
        subsequence_begins: List[int] = [0]
        block_indices: List[int] = []
        block_indices_begins: List[int] = [0]

        # Read the flag once, and None-safe, so every use below agrees.
        chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt
            computed_block_nums = seq_group_metadata.computed_block_nums
            block_tables = seq_group_metadata.block_tables

            for seq_id in seq_ids:
                if (chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    computed_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    computed_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    computed_len + seq_group_metadata.token_chunk_size,
                )
                if is_prompt:
                    tokens = seq_data.get_token_ids()[computed_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                # block_tables is None only when memory profiling runs; fall
                # back to an empty table instead of indexing None.
                block_table = ([] if block_tables is None else
                               block_tables[seq_id])
                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    computed_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[computed_len:]
                elif ((chunked_prefill_enabled or not is_prompt)
                      and self.sliding_window is not None):
                    # chunked prefill or decode with a sliding window:
                    # keep only the trailing blocks.
                    # chunked prefill doesn't support sliding window.
                    assert not chunked_prefill_enabled
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                # Otherwise (prompt phase w/o prefix_caching or
                # chunked_prefill, or no sliding window) the full block
                # table is used as-is.

                block_indices.extend(block_table)
                block_indices_begins.append(block_indices_begins[-1] +
                                            len(block_table))

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if self.sliding_window is not None and not is_prompt:
                    seq_len = min(seq_len, self.sliding_window)
                    computed_len = seq_len - 1

                seq_lens.append(seq_len)

                query_len = seq_len - computed_len
                query_lens.append(query_len)

                input_tokens.extend(tokens)
                positions_range = range(computed_len, seq_len)
                input_positions.extend(positions_range)

                past_lens.append(computed_len)
                subsequence_begins.append(subsequence_begins[-1] + query_len)

                if is_prompt:
                    assert len(seq_ids) == 1
                else:
                    assert (
                        query_len == 1
                    ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                        seq_len, computed_len, query_len)

                if seq_group_metadata.multi_modal_data:
                    # NOTE: mm_data only includes the subset of multi-modal
                    # items that intersect with the current prefill positions.
                    mm_data, placeholder_maps = MultiModalPlaceholderMap \
                        .from_seq_group(seq_group_metadata, positions_range)

                    mm_kwargs = self.multi_modal_input_mapper(
                        mm_data,
                        mm_processor_kwargs=seq_group_metadata.
                        mm_processor_kwargs)
                    multi_modal_inputs_list.append(mm_kwargs)

                    for modality, placeholder_map in placeholder_maps.items():
                        multi_modal_placeholder_maps[modality].extend(
                            placeholder_map, )

        max_query_len = max(query_lens)
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)  # type: ignore
        input_positions_tensor = torch.tensor(
            input_positions, dtype=torch.long,
            device=self.device)  # type: ignore

        past_lens_tensor = torch.tensor(past_lens,
                                        dtype=torch.int32,
                                        device=self.device)  # type: ignore
        subsequence_begins_tensor = torch.tensor(
            subsequence_begins, dtype=torch.int32,
            device=self.device)  # type: ignore
        block_indices_tensor = torch.tensor(block_indices,
                                            dtype=torch.int32,
                                            device=self.device)  # type: ignore
        block_indices_begins_tensor = torch.tensor(
            block_indices_begins, dtype=torch.int32,
            device=self.device)  # type: ignore

        max_context_len = max(seq_lens)
        max_context_len_tensor = torch.tensor(
            max_context_len, dtype=torch.int32,
            device=self.device)  # type: ignore

        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            multi_modal_placeholder_maps.items()
        }

        attn_metadata = self.attn_backend.make_openvino_metadata(
            past_lens=past_lens_tensor,
            subsequence_begins=subsequence_begins_tensor,
            block_indices=block_indices_tensor,
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return ModelInput(
            input_tokens_tensor,
            input_positions_tensor,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, BatchedTensorInputs]:
        """Build the model inputs and the matching sampling metadata.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for this step.

        Returns:
            A 5-tuple ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, multi_modal_kwargs)``. ``attn_metadata`` is
            whatever ``_prepare_model_input`` produced, which is ``None`` for
            an empty input list.
        """
        # Prepare input tensors.
        model_input = self._prepare_model_input(seq_group_metadata_list)

        # Sampling metadata is derived from the same per-sequence lengths
        # that were used to lay out the input tensors.
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            model_input.seq_lens,
            model_input.query_lens,
            self.device,
            pin_memory=False,
        )

        return (
            model_input.input_tokens,
            model_input.input_positions,
            model_input.attn_metadata,
            sampling_metadata,
            model_input.multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and sample the next tokens.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for this step.
            kv_caches: Pairs of OpenVINO tensors passed through to the model
                as ``kv_caches``.

        Returns:
            The result of ``self.model.sample`` for the computed logits.
        """
        (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            input_tokens,
            "positions":
            input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            attn_metadata,
            # Multi-modal inputs are expanded into extra keyword arguments;
            # an empty/None value contributes none.
            **MultiModalInputs.as_kwargs(multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output