Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcf2e3d7
"vscode:/vscode.git/clone" did not exist on "77f0d465d0a666b65dd877ec462f024a980dd55c"
Unverified
Commit
fcf2e3d7
authored
Feb 05, 2025
by
Harry Mellor
Committed by
GitHub
Feb 04, 2025
Browse files
[Bugfix] Fix OpenVINO model runner (#12750)
parent
58b218d7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
12 deletions
+12
-12
vllm/attention/backends/openvino.py
vllm/attention/backends/openvino.py
+4
-0
vllm/model_executor/model_loader/openvino.py
vllm/model_executor/model_loader/openvino.py
+5
-6
vllm/worker/openvino_model_runner.py
vllm/worker/openvino_model_runner.py
+3
-6
No files found.
vllm/attention/backends/openvino.py
View file @
fcf2e3d7
...
@@ -140,3 +140,7 @@ class OpenVINOAttentionMetadata:
...
@@ -140,3 +140,7 @@ class OpenVINOAttentionMetadata:
# `model_executable`.
# `model_executable`.
multi_modal_placeholder_index_maps
:
Optional
[
Dict
[
multi_modal_placeholder_index_maps
:
Optional
[
Dict
[
str
,
MultiModalPlaceholderMap
.
IndexMap
]]
str
,
MultiModalPlaceholderMap
.
IndexMap
]]
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation
:
bool
vllm/model_executor/model_loader/openvino.py
View file @
fcf2e3d7
...
@@ -13,7 +13,7 @@ from torch import nn
...
@@ -13,7 +13,7 @@ from torch import nn
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.openvino
import
OpenVINOAttentionMetadata
from
vllm.attention.backends.openvino
import
OpenVINOAttentionMetadata
from
vllm.config
import
DeviceConfig
,
ModelC
onfig
from
vllm.config
import
ModelConfig
,
VllmConfig
,
set_current_vllm_c
onfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
(
LogitsProcessor
,
from
vllm.model_executor.layers.logits_processor
import
(
LogitsProcessor
,
_prune_hidden_states
)
_prune_hidden_states
)
...
@@ -103,7 +103,6 @@ class OpenVINOCausalLM(nn.Module):
...
@@ -103,7 +103,6 @@ class OpenVINOCausalLM(nn.Module):
self
,
self
,
ov_core
:
ov
.
Core
,
ov_core
:
ov
.
Core
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
kv_cache_dtype
:
ov
.
Type
,
kv_cache_dtype
:
ov
.
Type
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -187,8 +186,7 @@ class OpenVINOCausalLM(nn.Module):
...
@@ -187,8 +186,7 @@ class OpenVINOCausalLM(nn.Module):
def
get_model
(
def
get_model
(
model_config
:
ModelConfig
,
vllm_config
:
VllmConfig
,
device_config
:
DeviceConfig
,
kv_cache_dtype
:
ov
.
Type
,
kv_cache_dtype
:
ov
.
Type
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
nn
.
Module
:
)
->
torch
.
nn
.
Module
:
...
@@ -201,5 +199,6 @@ def get_model(
...
@@ -201,5 +199,6 @@ def get_model(
"be added in the future. If this is important to you, "
"be added in the future. If this is important to you, "
"please open an issue on github."
)
"please open an issue on github."
)
return
OpenVINOCausalLM
(
ov_core
,
model_config
,
device_config
,
with
set_current_vllm_config
(
vllm_config
):
return
OpenVINOCausalLM
(
ov_core
,
vllm_config
.
model_config
,
kv_cache_dtype
)
kv_cache_dtype
)
vllm/worker/openvino_model_runner.py
View file @
fcf2e3d7
...
@@ -54,15 +54,13 @@ class OpenVINOModelRunner(ModelRunnerBase):
...
@@ -54,15 +54,13 @@ class OpenVINOModelRunner(ModelRunnerBase):
):
):
self
.
ov_core
=
ov_core
self
.
ov_core
=
ov_core
ModelRunnerBase
.
__init__
(
self
,
vllm_config
=
vllm_config
)
ModelRunnerBase
.
__init__
(
self
,
vllm_config
=
vllm_config
)
cache_config
=
self
.
cache_config
model_config
=
self
.
model_config
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
sliding_window
=
self
.
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
block_size
=
self
.
cache_config
.
block_size
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
...
@@ -81,8 +79,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
...
@@ -81,8 +79,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
model
:
nn
.
Module
# Set after init_Model
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
device_config
=
self
.
device_config
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
ov_core
=
self
.
ov_core
)
ov_core
=
self
.
ov_core
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment