"vscode:/vscode.git/clone" did not exist on "77f0d465d0a666b65dd877ec462f024a980dd55c"
Unverified Commit fcf2e3d7 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

[Bugfix] Fix OpenVINO model runner (#12750)

parent 58b218d7
...@@ -140,3 +140,7 @@ class OpenVINOAttentionMetadata: ...@@ -140,3 +140,7 @@ class OpenVINOAttentionMetadata:
# `model_executable`. # `model_executable`.
multi_modal_placeholder_index_maps: Optional[Dict[ multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]] str, MultiModalPlaceholderMap.IndexMap]]
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool
...@@ -13,7 +13,7 @@ from torch import nn ...@@ -13,7 +13,7 @@ from torch import nn
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import DeviceConfig, ModelConfig from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor, from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states) _prune_hidden_states)
...@@ -103,7 +103,6 @@ class OpenVINOCausalLM(nn.Module): ...@@ -103,7 +103,6 @@ class OpenVINOCausalLM(nn.Module):
self, self,
ov_core: ov.Core, ov_core: ov.Core,
model_config: ModelConfig, model_config: ModelConfig,
device_config: DeviceConfig,
kv_cache_dtype: ov.Type, kv_cache_dtype: ov.Type,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -187,8 +186,7 @@ class OpenVINOCausalLM(nn.Module): ...@@ -187,8 +186,7 @@ class OpenVINOCausalLM(nn.Module):
def get_model( def get_model(
model_config: ModelConfig, vllm_config: VllmConfig,
device_config: DeviceConfig,
kv_cache_dtype: ov.Type, kv_cache_dtype: ov.Type,
**kwargs, **kwargs,
) -> torch.nn.Module: ) -> torch.nn.Module:
...@@ -201,5 +199,6 @@ def get_model( ...@@ -201,5 +199,6 @@ def get_model(
"be added in the future. If this is important to you, " "be added in the future. If this is important to you, "
"please open an issue on github.") "please open an issue on github.")
return OpenVINOCausalLM(ov_core, model_config, device_config, with set_current_vllm_config(vllm_config):
return OpenVINOCausalLM(ov_core, vllm_config.model_config,
kv_cache_dtype) kv_cache_dtype)
...@@ -54,15 +54,13 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -54,15 +54,13 @@ class OpenVINOModelRunner(ModelRunnerBase):
): ):
self.ov_core = ov_core self.ov_core = ov_core
ModelRunnerBase.__init__(self, vllm_config=vllm_config) ModelRunnerBase.__init__(self, vllm_config=vllm_config)
cache_config = self.cache_config
model_config = self.model_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.device = self.device_config.device self.device = self.device_config.device
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window() self.sliding_window = self.model_config.get_sliding_window()
self.block_size = cache_config.block_size self.block_size = self.cache_config.block_size
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_head_size(), self.model_config.get_head_size(),
...@@ -81,8 +79,7 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -81,8 +79,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(model_config=self.model_config, self.model = get_model(vllm_config=self.vllm_config,
device_config=self.device_config,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
ov_core=self.ov_core) ov_core=self.ov_core)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment