Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5cbdccd1
Unverified
Commit
5cbdccd1
authored
Oct 26, 2024
by
Mengqing Cao
Committed by
GitHub
Oct 26, 2024
Browse files
[Hardware][openvino] is_openvino --> current_platform.is_openvino (#9716)
parent
067e77f9
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
69 additions
and
38 deletions
+69
-38
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+2
-1
vllm/attention/selector.py
vllm/attention/selector.py
+2
-2
vllm/config.py
vllm/config.py
+2
-2
vllm/executor/openvino_executor.py
vllm/executor/openvino_executor.py
+7
-13
vllm/model_executor/model_loader/openvino.py
vllm/model_executor/model_loader/openvino.py
+2
-2
vllm/platforms/__init__.py
vllm/platforms/__init__.py
+10
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+4
-0
vllm/platforms/openvino.py
vllm/platforms/openvino.py
+31
-0
vllm/utils.py
vllm/utils.py
+1
-10
vllm/worker/openvino_worker.py
vllm/worker/openvino_worker.py
+8
-8
No files found.
tests/kernels/test_attention_selector.py
View file @
5cbdccd1
...
...
@@ -30,7 +30,8 @@ def test_env(name: str, device: str, monkeypatch):
False
)
assert
backend
.
name
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.is_openvino"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.current_platform.is_openvino"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"OPENVINO"
...
...
vllm/attention/selector.py
View file @
5cbdccd1
...
...
@@ -10,7 +10,7 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
is_hip
,
is_openvino
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
is_hip
logger
=
init_logger
(
__name__
)
...
...
@@ -193,7 +193,7 @@ def which_attn_to_use(
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
return
_Backend
.
TORCH_SDPA
if
is_openvino
():
if
current_platform
.
is_openvino
():
if
selected_backend
!=
_Backend
.
OPENVINO
:
logger
.
info
(
"Cannot use %s backend on OpenVINO."
,
selected_backend
)
return
_Backend
.
OPENVINO
...
...
vllm/config.py
View file @
5cbdccd1
...
...
@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config
,
get_hf_text_config
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
is_hip
,
is_openvino
,
print_warning_once
)
is_hip
,
print_warning_once
)
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
...
...
@@ -1117,7 +1117,7 @@ class DeviceConfig:
self
.
device_type
=
"cuda"
elif
current_platform
.
is_neuron
():
self
.
device_type
=
"neuron"
elif
is_openvino
():
elif
current_platform
.
is_openvino
():
self
.
device_type
=
"openvino"
elif
current_platform
.
is_tpu
():
self
.
device_type
=
"tpu"
...
...
vllm/executor/openvino_executor.py
View file @
5cbdccd1
...
...
@@ -10,6 +10,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
GiB_bytes
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
...
...
@@ -17,14 +18,6 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
logger
=
init_logger
(
__name__
)
def
is_openvino_cpu
()
->
bool
:
return
"CPU"
in
envs
.
VLLM_OPENVINO_DEVICE
def
is_openvino_gpu
()
->
bool
:
return
"GPU"
in
envs
.
VLLM_OPENVINO_DEVICE
class
OpenVINOExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
...
...
@@ -32,7 +25,8 @@ class OpenVINOExecutor(ExecutorBase):
def
_init_executor
(
self
)
->
None
:
assert
self
.
device_config
.
device_type
==
"openvino"
assert
self
.
lora_config
is
None
,
"OpenVINO backend doesn't support LoRA"
assert
is_openvino_cpu
()
or
is_openvino_gpu
(),
\
assert
current_platform
.
is_openvino_cpu
()
or
\
current_platform
.
is_openvino_gpu
(),
\
"OpenVINO backend supports only CPU and GPU devices"
self
.
ov_core
=
ov
.
Core
()
...
...
@@ -163,7 +157,7 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
def
_verify_and_get_cache_config
(
ov_core
:
ov
.
Core
,
config
:
CacheConfig
)
->
CacheConfig
:
if
envs
.
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
==
"u8"
:
if
not
is_openvino_cpu
():
if
not
current_platform
.
is_openvino_cpu
():
logger
.
info
(
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
"ignored for GPU, f16 data type will be used."
)
config
.
cache_dtype
=
ov
.
Type
.
f16
...
...
@@ -172,7 +166,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var."
)
config
.
cache_dtype
=
ov
.
Type
.
u8
else
:
if
is_openvino_cpu
():
if
current_platform
.
is_openvino_cpu
():
ov_device
=
envs
.
VLLM_OPENVINO_DEVICE
inference_precision
=
ov_core
.
get_property
(
ov_device
,
hints
.
inference_precision
)
...
...
@@ -183,7 +177,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
else
:
config
.
cache_dtype
=
ov
.
Type
.
f16
if
is_openvino_cpu
():
if
current_platform
.
is_openvino_cpu
():
if
config
.
block_size
!=
32
:
logger
.
info
(
f
"OpenVINO CPU optimal block size is 32, overriding currently set
{
config
.
block_size
}
"
# noqa: G004, E501
...
...
@@ -198,7 +192,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
kv_cache_space
=
envs
.
VLLM_OPENVINO_KVCACHE_SPACE
if
kv_cache_space
>=
0
:
if
kv_cache_space
==
0
and
is_openvino_cpu
():
if
kv_cache_space
==
0
and
current_platform
.
is_openvino_cpu
():
config
.
openvino_kvcache_space_bytes
=
4
*
GiB_bytes
# type: ignore
logger
.
warning
(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
...
...
vllm/model_executor/model_loader/openvino.py
View file @
5cbdccd1
...
...
@@ -12,12 +12,12 @@ from torch import nn
import
vllm.envs
as
envs
from
vllm.attention.backends.openvino
import
OpenVINOAttentionMetadata
from
vllm.config
import
DeviceConfig
,
ModelConfig
from
vllm.executor.openvino_executor
import
is_openvino_cpu
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
(
LogitsProcessor
,
_prune_hidden_states
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
...
...
@@ -136,7 +136,7 @@ class OpenVINOCasualLM(nn.Module):
ov_device
=
envs
.
VLLM_OPENVINO_DEVICE
paged_attention_transformation
(
pt_model
.
model
)
_modify_cache_parameters
(
pt_model
.
model
,
kv_cache_dtype
,
is_openvino_cpu
())
current_platform
.
is_openvino_cpu
())
ov_compiled
=
ov_core
.
compile_model
(
pt_model
.
model
,
ov_device
)
self
.
ov_request
=
ov_compiled
.
create_infer_request
()
...
...
vllm/platforms/__init__.py
View file @
5cbdccd1
...
...
@@ -65,6 +65,13 @@ try:
except
ImportError
:
pass
is_openvino
=
False
try
:
from
importlib.metadata
import
version
is_openvino
=
"openvino"
in
version
(
"vllm"
)
except
Exception
:
pass
if
is_tpu
:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
...
...
@@ -85,6 +92,9 @@ elif is_cpu:
elif
is_neuron
:
from
.neuron
import
NeuronPlatform
current_platform
=
NeuronPlatform
()
elif
is_openvino
:
from
.openvino
import
OpenVinoPlatform
current_platform
=
OpenVinoPlatform
()
else
:
current_platform
=
UnspecifiedPlatform
()
...
...
vllm/platforms/interface.py
View file @
5cbdccd1
...
...
@@ -11,6 +11,7 @@ class PlatformEnum(enum.Enum):
XPU
=
enum
.
auto
()
CPU
=
enum
.
auto
()
NEURON
=
enum
.
auto
()
OPENVINO
=
enum
.
auto
()
UNSPECIFIED
=
enum
.
auto
()
...
...
@@ -52,6 +53,9 @@ class Platform:
def
is_neuron
(
self
)
->
bool
:
return
self
.
_enum
==
PlatformEnum
.
NEURON
def
is_openvino
(
self
)
->
bool
:
return
self
.
_enum
==
PlatformEnum
.
OPENVINO
def
is_cuda_alike
(
self
)
->
bool
:
"""Stateless version of :func:`torch.cuda.is_available`."""
return
self
.
_enum
in
(
PlatformEnum
.
CUDA
,
PlatformEnum
.
ROCM
)
...
...
vllm/platforms/openvino.py
0 → 100644
View file @
5cbdccd1
import
torch
import
vllm.envs
as
envs
from
vllm.utils
import
print_warning_once
from
.interface
import
Platform
,
PlatformEnum
class
OpenVinoPlatform
(
Platform
):
_enum
=
PlatformEnum
.
OPENVINO
@
classmethod
def
get_device_name
(
self
,
device_id
:
int
=
0
)
->
str
:
return
"openvino"
@
classmethod
def
inference_mode
(
self
):
return
torch
.
inference_mode
(
mode
=
True
)
@
classmethod
def
is_openvino_cpu
(
self
)
->
bool
:
return
"CPU"
in
envs
.
VLLM_OPENVINO_DEVICE
@
classmethod
def
is_openvino_gpu
(
self
)
->
bool
:
return
"GPU"
in
envs
.
VLLM_OPENVINO_DEVICE
@
classmethod
def
is_pin_memory_available
(
self
)
->
bool
:
print_warning_once
(
"Pin memory is not supported on OpenViNO."
)
return
False
vllm/utils.py
View file @
5cbdccd1
...
...
@@ -318,15 +318,6 @@ def is_hip() -> bool:
return
torch
.
version
.
hip
is
not
None
@
lru_cache
(
maxsize
=
None
)
def
is_openvino
()
->
bool
:
from
importlib.metadata
import
PackageNotFoundError
,
version
try
:
return
"openvino"
in
version
(
"vllm"
)
except
PackageNotFoundError
:
return
False
@
lru_cache
(
maxsize
=
None
)
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
"""Returns the maximum shared memory per thread block in bytes."""
...
...
@@ -757,7 +748,7 @@ def is_pin_memory_available() -> bool:
elif
current_platform
.
is_neuron
():
print_warning_once
(
"Pin memory is not supported on Neuron."
)
return
False
elif
current_platform
.
is_cpu
()
or
is_openvino
():
elif
current_platform
.
is_cpu
()
or
current_platform
.
is_openvino
():
return
False
return
True
...
...
vllm/worker/openvino_worker.py
View file @
5cbdccd1
...
...
@@ -13,12 +13,12 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from
vllm.distributed
import
(
broadcast_tensor_dict
,
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.executor.openvino_executor
import
is_openvino_cpu
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.worker.openvino_model_runner
import
OpenVINOModelRunner
...
...
@@ -99,7 +99,7 @@ class OpenVINOCacheEngine:
num_blocks
,
self
.
block_size
,
self
.
num_kv_heads
,
self
.
head_size
)[
1
:]
kv_cache
:
List
[
Tuple
[
ov
.
Tensor
,
ov
.
Tensor
]]
=
[]
if
is_openvino_cpu
():
if
current_platform
.
is_openvino_cpu
():
for
_
in
range
(
self
.
num_layers
):
key_blocks
=
ov
.
Tensor
(
self
.
cache_config
.
cache_dtype
,
k_block_shape
)
...
...
@@ -141,7 +141,7 @@ class OpenVINOCacheEngine:
if
num_blocks
==
0
:
return
swap_cache
assert
not
is_openvino_cpu
(),
\
assert
not
current_platform
.
is_openvino_cpu
(),
\
"CPU device isn't supposed to have swap cache"
# Update key_cache shape:
...
...
@@ -285,7 +285,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
cache_block_size
=
self
.
get_cache_block_size_bytes
()
kvcache_space_bytes
=
self
.
cache_config
.
openvino_kvcache_space_bytes
if
is_openvino_cpu
():
if
current_platform
.
is_openvino_cpu
():
num_device_blocks
=
int
(
kvcache_space_bytes
//
cache_block_size
)
num_swap_blocks
=
0
else
:
...
...
@@ -322,7 +322,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
num_device_blocks
=
num_gpu_blocks
num_swap_blocks
=
num_cpu_blocks
if
is_openvino_cpu
():
if
current_platform
.
is_openvino_cpu
():
assert
(
num_swap_blocks
==
0
),
f
"
{
type
(
self
)
}
does not support swappable cache for CPU"
...
...
@@ -366,7 +366,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
assert
self
.
kv_cache
is
not
None
# Populate the cache to warmup the memory
if
is_openvino_cpu
():
if
current_platform
.
is_openvino_cpu
():
for
key_cache
,
value_cache
in
self
.
kv_cache
:
key_cache
.
data
[:]
=
0
value_cache
.
data
[:]
=
0
...
...
@@ -414,7 +414,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in
=
data
[
"blocks_to_swap_in"
]
blocks_to_swap_out
=
data
[
"blocks_to_swap_out"
]
if
is_openvino_cpu
():
if
current_platform
.
is_openvino_cpu
():
assert
len
(
execute_model_req
.
blocks_to_swap_in
)
==
0
assert
len
(
execute_model_req
.
blocks_to_swap_out
)
==
0
else
:
...
...
@@ -466,7 +466,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
def
profile_run
(
self
)
->
int
:
ov_device
=
envs
.
VLLM_OPENVINO_DEVICE
assert
not
is_openvino_cpu
(),
\
assert
not
current_platform
.
is_openvino_cpu
(),
\
"CPU device isn't supposed to use profile run."
import
openvino.properties.device
as
device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment