# openvino_model_runner.py
from typing import List, NamedTuple, Optional, Tuple

import openvino as ov
import torch
from torch import nn

from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalInputs)
from vllm.sequence import SequenceGroupMetadata

logger = init_logger(__name__)


class ModelInput(NamedTuple):
    """Batched inputs for a single OpenVINO model invocation."""

    # Flat token ids of all scheduled sequences (int64).
    input_tokens: torch.Tensor
    # Position of each entry in ``input_tokens`` (int64).
    input_positions: torch.Tensor
    # Paged-attention metadata; None for an empty batch.
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    # Per-sequence total length (after any sliding-window clamp).
    seq_lens: List[int]
    # Per-sequence number of tokens fed to the model in this step.
    query_lens: List[int]
    # Batched multi-modal tensors ({} when there is no multi-modal data).
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device):
        """Return a ModelInput describing an empty batch on ``device``.

        The token and position tensors are created as ``torch.long`` so
        they have the same dtype as the tensors built for non-empty batches
        (``torch.empty`` would otherwise default to a floating dtype, which
        cannot be used as embedding indices).
        """
        return cls(
            input_tokens=torch.empty(0, dtype=torch.long, device=device),
            input_positions=torch.empty(0, dtype=torch.long, device=device),
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            multi_modal_kwargs={},
        )


class OpenVINOModelRunner:
    """Build batched inputs for an OpenVINO model and execute it.

    ``_prepare_model_input`` flattens the scheduled sequence groups into
    token/position tensors plus the OpenVINO paged-attention metadata
    (past lengths, subsequence offsets and block indices).
    ``execute_model`` runs the model on those inputs, computes the logits
    and samples the next tokens.
    """

    def __init__(
        self,
        ov_core: ov.Core,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        # Extra positional/keyword arguments are accepted and ignored.
        self.ov_core = ov_core
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set by load_model().

    def load_model(self) -> None:
        """Load the model through the OpenVINO loader into ``self.model``."""
        self.model = get_model(model_config=self.model_config,
                               device_config=self.device_config,
                               kv_cache_dtype=self.kv_cache_dtype,
                               ov_core=self.ov_core)

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        Raises:
            RuntimeError: if chunked prefill is enabled and a sequence group
                reports prefix-cached (computed) blocks.
        """
        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        input_tokens: List[int] = []
        input_positions: List[int] = []

        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        # Prefix sums: sequence i owns query tokens
        # [subsequence_begins[i], subsequence_begins[i + 1]) and block
        # indices [block_indices_begins[i], block_indices_begins[i + 1]).
        subsequence_begins: List[int] = [0]
        block_indices: List[int] = []
        block_indices_begins: List[int] = [0]

        # Evaluate once, treating a missing scheduler config as "disabled"
        # everywhere (previously only one of three call sites checked None).
        chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt
            computed_block_nums = seq_group_metadata.computed_block_nums
            block_tables = seq_group_metadata.block_tables

            for seq_id in seq_ids:
                if chunked_prefill_enabled and computed_block_nums:
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    computed_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    computed_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    computed_len + seq_group_metadata.token_chunk_size,
                )
                if is_prompt:
                    tokens = seq_data.get_token_ids()[computed_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (bool(computed_block_nums)
                                    and self.sliding_window is None
                                    and is_prompt)

                mm_data = seq_group_metadata.multi_modal_data
                if mm_data:
                    mm_kwargs = self.multi_modal_input_mapper(
                        mm_data,
                        mm_processor_kwargs=seq_group_metadata.
                        mm_processor_kwargs,
                    )
                    multi_modal_inputs_list.append(mm_kwargs)

                # The block table is read only when one was supplied; the
                # previous unconditional lookup raised TypeError when
                # block_tables was None, making the [] fallback unreachable.
                if block_tables is not None:
                    block_table = block_tables[seq_id]
                else:
                    # Only happens when memory profiling runs.
                    block_table = []

                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    computed_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[computed_len:]
                elif chunked_prefill_enabled or not is_prompt:
                    # chunked prefill or decode
                    if (block_tables is not None
                            and self.sliding_window is not None):
                        # chunked prefill doesn't support sliding window.
                        assert not chunked_prefill_enabled
                        sliding_window_blocks = (self.sliding_window //
                                                 self.block_size)
                        block_table = block_table[-sliding_window_blocks:]
                # else: prompt phase w/o prefix_caching, chunked_prefill keeps
                # the full block table.

                block_indices.extend(block_table)
                block_indices_begins.append(block_indices_begins[-1] +
                                            len(block_table))

                # Record the real token positions before the sliding-window
                # clamp below; the clamp previously leaked into the positions
                # (every long decode got position sliding_window - 1).
                positions = range(computed_len, seq_len)

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if self.sliding_window is not None and not is_prompt:
                    seq_len = min(seq_len, self.sliding_window)
                    computed_len = seq_len - 1

                seq_lens.append(seq_len)

                query_len = seq_len - computed_len
                query_lens.append(query_len)

                input_tokens.extend(tokens)
                input_positions.extend(positions)

                past_lens.append(computed_len)
                subsequence_begins.append(subsequence_begins[-1] + query_len)

                if is_prompt:
                    assert len(seq_ids) == 1
                else:
                    assert (
                        query_len == 1
                    ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                        seq_len, computed_len, query_len)

        max_query_len = max(query_lens)
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=self.device)

        past_lens_tensor = torch.tensor(past_lens,
                                        dtype=torch.int32,
                                        device=self.device)
        subsequence_begins_tensor = torch.tensor(subsequence_begins,
                                                 dtype=torch.int32,
                                                 device=self.device)
        block_indices_tensor = torch.tensor(block_indices,
                                            dtype=torch.int32,
                                            device=self.device)
        block_indices_begins_tensor = torch.tensor(block_indices_begins,
                                                   dtype=torch.int32,
                                                   device=self.device)

        max_context_len = max(seq_lens)
        max_context_len_tensor = torch.tensor(max_context_len,
                                              dtype=torch.int32,
                                              device=self.device)

        attn_metadata = self.attn_backend.make_openvino_metadata(
            past_lens=past_lens_tensor,
            subsequence_begins=subsequence_begins_tensor,
            block_indices=block_indices_tensor,
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return ModelInput(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, BatchedTensorInputs]:
        """Build model inputs and the matching sampling metadata.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, multi_modal_kwargs)``.
        """
        (
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens,
            self.device,
            pin_memory=False,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and sample the next tokens.

        Args:
            seq_group_metadata_list: scheduled sequence groups, sorted
                prefill first, then decode.
            kv_caches: KV-cache tensor pairs passed straight to the model.

        Returns:
            The result of ``self.model.sample`` for this batch.
        """
        (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            input_tokens,
            "positions":
            input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            attn_metadata,
            **MultiModalInputs.as_kwargs(multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output