Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
86a44fb8
Unverified
Commit
86a44fb8
authored
Nov 23, 2024
by
JiHuazhong
Committed by
GitHub
Nov 22, 2024
Browse files
[Platforms] Refactor openvino code (#10573)
Signed-off-by:
statelesshz
<
hzji210@gmail.com
>
parent
4cfe5d2b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
78 deletions
+72
-78
vllm/executor/openvino_executor.py
vllm/executor/openvino_executor.py
+3
-78
vllm/platforms/openvino.py
vllm/platforms/openvino.py
+69
-0
No files found.
vllm/executor/openvino_executor.py
View file @
86a44fb8
from
typing
import
List
,
Set
,
Tuple
from
typing
import
List
,
Set
,
Tuple
import
openvino
as
ov
import
openvino
as
ov
import
openvino.properties.hint
as
hints
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
CacheConfig
,
ModelConfig
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
GiB_bytes
,
get_distributed_init_method
,
get_ip
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_open_port
,
make_async
)
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -30,11 +27,6 @@ class OpenVINOExecutor(ExecutorBase):
...
@@ -30,11 +27,6 @@ class OpenVINOExecutor(ExecutorBase):
current_platform
.
is_openvino_gpu
(),
\
current_platform
.
is_openvino_gpu
(),
\
"OpenVINO backend supports only CPU and GPU devices"
"OpenVINO backend supports only CPU and GPU devices"
self
.
ov_core
=
ov
.
Core
()
self
.
model_config
=
_verify_and_get_model_config
(
self
.
model_config
)
self
.
cache_config
=
_verify_and_get_cache_config
(
self
.
ov_core
,
self
.
cache_config
)
# Instantiate the worker and load the model to CPU.
# Instantiate the worker and load the model to CPU.
self
.
_init_worker
()
self
.
_init_worker
()
...
@@ -45,7 +37,7 @@ class OpenVINOExecutor(ExecutorBase):
...
@@ -45,7 +37,7 @@ class OpenVINOExecutor(ExecutorBase):
distributed_init_method
=
get_distributed_init_method
(
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
get_ip
(),
get_open_port
())
self
.
driver_worker
=
wrapper
.
init_worker
(
self
.
driver_worker
=
wrapper
.
init_worker
(
ov_core
=
self
.
ov_c
ore
,
ov_core
=
ov
.
C
ore
()
,
vllm_config
=
self
.
vllm_config
,
vllm_config
=
self
.
vllm_config
,
local_rank
=
0
,
local_rank
=
0
,
rank
=
0
,
rank
=
0
,
...
@@ -130,70 +122,3 @@ class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):
...
@@ -130,70 +122,3 @@ class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):
# OpenVINOExecutor will always be healthy as long as
# OpenVINOExecutor will always be healthy as long as
# it's running.
# it's running.
return
return
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    """Force the model config into the settings used with OpenVINO.

    ``config`` is modified in place and is also returned.

    Args:
        config: Model configuration to check.

    Returns:
        The same ``config`` object, with ``dtype`` set to ``torch.float32``
        and ``enforce_eager`` set to ``True``.
    """
    if config.dtype != torch.float32:
        # Lazy %-style arguments produce the same text as the previous
        # f-string, but the message is only formatted if it is emitted.
        logger.warning(
            "Only float32 dtype is supported on OpenVINO, casting from %s.",
            config.dtype)
        config.dtype = torch.float32

    if not config.enforce_eager:
        logger.warning(
            "CUDA graph is not supported on OpenVINO backend, fallback to the "
            "eager mode.")
        config.enforce_eager = True

    return config
def _verify_and_get_cache_config(ov_core: ov.Core,
                                 config: CacheConfig) -> CacheConfig:
    """Adjust the KV-cache settings for the active OpenVINO device.

    ``config`` is modified in place and is also returned. Three fields are
    written: ``cache_dtype``, ``block_size`` and
    ``openvino_kvcache_space_bytes``.

    Args:
        ov_core: OpenVINO core, queried for the device inference precision.
        config: Cache configuration to check.

    Returns:
        The same ``config`` object after the adjustments.

    Raises:
        RuntimeError: If ``VLLM_OPENVINO_KVCACHE_SPACE`` is negative.
    """
    # --- KV cache data type -------------------------------------------------
    if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
        if not current_platform.is_openvino_cpu():
            # The adjacent literals need the space after "is"; without it
            # the message reads "...PRECISION isignored for GPU...".
            logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is "
                        "ignored for GPU, f16 data type will be used.")
            config.cache_dtype = ov.Type.f16
        else:
            logger.info("KV cache type is overridden to u8 via "
                        "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
            config.cache_dtype = ov.Type.u8
    elif current_platform.is_openvino_cpu():
        ov_device = envs.VLLM_OPENVINO_DEVICE
        inference_precision = ov_core.get_property(ov_device,
                                                   hints.inference_precision)
        # Follow the device's bf16 inference precision; otherwise use f16.
        if inference_precision == ov.Type.bf16:
            config.cache_dtype = ov.Type.bf16
        else:
            config.cache_dtype = ov.Type.f16
    else:
        config.cache_dtype = ov.Type.f16

    # --- KV cache block size ------------------------------------------------
    if current_platform.is_openvino_cpu():
        if config.block_size != 32:
            logger.info(
                "OpenVINO CPU optimal block size is 32, overriding currently "
                "set %s", config.block_size)
            config.block_size = 32
    else:
        if config.block_size != 16:
            logger.info(
                "OpenVINO GPU optimal block size is 16, overriding currently "
                "set %s", config.block_size)
            config.block_size = 16

    # --- KV cache size ------------------------------------------------------
    kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
    if kv_cache_space < 0:
        # Zero is accepted below, so the requirement is "non-negative".
        # The space before the value keeps it apart from the variable name.
        raise RuntimeError(
            "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE "
            f"{kv_cache_space}, expect a non-negative integer value.")

    if kv_cache_space == 0 and current_platform.is_openvino_cpu():
        # On CPU a value of 0 is treated as "not set" and replaced by 4 GiB.
        config.openvino_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
        logger.warning(
            "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
            "for OpenVINO backend is not set, using 4 by default.")
    else:
        config.openvino_kvcache_space_bytes = (  # type: ignore
            kv_cache_space * GiB_bytes)

    return config
vllm/platforms/openvino.py
View file @
86a44fb8
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
import
openvino
as
ov
import
openvino.properties.hint
as
hints
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -49,6 +51,8 @@ class OpenVinoPlatform(Platform):
...
@@ -49,6 +51,8 @@ class OpenVinoPlatform(Platform):
@
classmethod
@
classmethod
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
from
vllm.utils
import
GiB_bytes
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
assert
(
assert
(
parallel_config
.
world_size
==
1
parallel_config
.
world_size
==
1
...
@@ -57,3 +61,68 @@ class OpenVinoPlatform(Platform):
...
@@ -57,3 +61,68 @@ class OpenVinoPlatform(Platform):
if
parallel_config
.
worker_cls
==
"auto"
:
if
parallel_config
.
worker_cls
==
"auto"
:
parallel_config
.
worker_cls
=
\
parallel_config
.
worker_cls
=
\
"vllm.worker.openvino_worker.OpenVINOWorker"
"vllm.worker.openvino_worker.OpenVINOWorker"
# check and update model config
model_config
=
vllm_config
.
model_config
if
model_config
.
dtype
!=
torch
.
float32
:
logger
.
warning
(
f
"Only float32 dtype is supported on OpenVINO, casting from
{
model_config
.
dtype
}
."
# noqa: G004, E501
)
model_config
.
dtype
=
torch
.
float32
if
not
model_config
.
enforce_eager
:
logger
.
warning
(
"CUDA graph is not supported on OpenVINO backend, fallback to "
"the eager mode."
)
model_config
.
enforce_eager
=
True
# check and update cache config
ov_core
=
ov
.
Core
()
cache_config
=
vllm_config
.
cache_config
if
envs
.
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
==
"u8"
:
if
not
OpenVinoPlatform
.
is_openvino_cpu
():
logger
.
info
(
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
"ignored for GPU, f16 data type will be used."
)
cache_config
.
cache_dtype
=
ov
.
Type
.
f16
else
:
logger
.
info
(
"KV cache type is overridden to u8 via "
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var."
)
cache_config
.
cache_dtype
=
ov
.
Type
.
u8
else
:
if
OpenVinoPlatform
.
is_openvino_cpu
():
ov_device
=
envs
.
VLLM_OPENVINO_DEVICE
inference_precision
=
ov_core
.
get_property
(
ov_device
,
hints
.
inference_precision
)
if
inference_precision
==
ov
.
Type
.
bf16
:
cache_config
.
cache_dtype
=
ov
.
Type
.
bf16
else
:
cache_config
.
cache_dtype
=
ov
.
Type
.
f16
else
:
cache_config
.
cache_dtype
=
ov
.
Type
.
f16
if
OpenVinoPlatform
.
is_openvino_cpu
():
if
cache_config
.
block_size
!=
32
:
logger
.
info
(
f
"OpenVINO CPU optimal block size is 32, overriding currently set
{
cache_config
.
block_size
}
"
# noqa: G004, E501
)
cache_config
.
block_size
=
32
else
:
if
cache_config
.
block_size
!=
16
:
logger
.
info
(
f
"OpenVINO GPU optimal block size is 16, overriding currently set
{
cache_config
.
block_size
}
"
# noqa: G004, E501
)
cache_config
.
block_size
=
16
kv_cache_space
=
envs
.
VLLM_OPENVINO_KVCACHE_SPACE
if
kv_cache_space
>=
0
:
if
kv_cache_space
==
0
and
OpenVinoPlatform
.
is_openvino_cpu
():
cache_config
.
openvino_kvcache_space_bytes
=
4
*
GiB_bytes
# type: ignore
logger
.
warning
(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
"for OpenVINO backend is not set, using 4 by default."
)
else
:
cache_config
.
openvino_kvcache_space_bytes
=
(
# type: ignore
kv_cache_space
*
GiB_bytes
)
else
:
raise
RuntimeError
(
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
f
"
{
kv_cache_space
}
, expect a positive integer value."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment