Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c42590f9
Unverified
Commit
c42590f9
authored
Aug 21, 2024
by
Kunshang Ji
Committed by
GitHub
Aug 20, 2024
Browse files
[Hardware] [Intel GPU] refactor xpu worker/executor (#7686)
parent
aae6927b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
28 deletions
+26
-28
vllm/executor/xpu_executor.py
vllm/executor/xpu_executor.py
+15
-23
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+0
-1
vllm/worker/xpu_worker.py
vllm/worker/xpu_worker.py
+11
-4
No files found.
vllm/executor/xpu_executor.py
View file @
c42590f9
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.utils
import
make_async
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
...
...
@@ -30,6 +30,7 @@ class XPUExecutor(GPUExecutor):
lora_config
:
Optional
[
LoRAConfig
],
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
observability_config
:
Optional
[
ObservabilityConfig
],
)
->
None
:
assert
device_config
.
device_type
==
"xpu"
assert
(
not
speculative_config
...
...
@@ -46,32 +47,23 @@ class XPUExecutor(GPUExecutor):
self
.
device_config
=
device_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
speculative_config
=
None
self
.
observability_config
=
observability_config
# Instantiate the worker and load the model to GPU.
self
.
_init_executor
()
def
_create_worker
(
self
,
local_rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
):
if
self
.
speculative_config
is
None
:
worker_module_name
=
"vllm.worker.xpu_worker"
worker_class_name
=
"XPUWorker"
else
:
def
_get_worker_module_and_class
(
self
)
->
Tuple
[
str
,
str
]:
if
self
.
speculative_config
is
not
None
:
raise
NotImplementedError
(
"XPU does not support speculative decoding"
)
wrapper
=
WorkerWrapperBase
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
)
wrapper
.
init_worker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
wrapper
.
worker
else
:
worker_module_name
=
"vllm.worker.xpu_worker"
worker_class_name
=
"XPUWorker"
return
(
worker_module_name
,
worker_class_name
)
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Optional
[
List
[
Union
[
SamplerOutput
,
PoolerOutput
]]
]:
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
return
output
...
...
vllm/worker/xpu_model_runner.py
View file @
c42590f9
...
...
@@ -137,7 +137,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
device_config
=
self
.
device_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
multimodal_config
=
self
.
multimodal_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
,
...
...
vllm/worker/xpu_worker.py
View file @
c42590f9
...
...
@@ -9,8 +9,8 @@ import torch
import
torch.distributed
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
Parallel
Config
,
PromptAdapterConfig
,
SchedulerConfig
,
ModelConfig
,
MultiModalConfig
,
Observability
Config
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
...
...
@@ -50,6 +50,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
)
->
None
:
assert
device_config
.
device_type
==
"xpu"
assert
is_xpu
()
...
...
@@ -67,8 +68,10 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
self
.
lora_config
=
lora_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
is_driver_worker
=
is_driver_worker
if
self
.
is_driver_worker
:
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
self
.
observability_config
=
observability_config
if
parallel_config
and
is_driver_worker
:
assert
rank
%
parallel_config
.
tensor_parallel_size
==
0
,
\
"Driver worker should be rank 0 of tensor parallel group."
self
.
multimodal_config
=
multimodal_config
...
...
@@ -183,7 +186,11 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
# dependency (libdrm and drm headers) on your system.
ENV_CCL_ZE_IPC_EXCHANGE
=
os
.
getenv
(
"CCL_ZE_IPC_EXCHANGE"
,
"sockets"
)
ENV_LOCAL_WORLD_SIZE
=
os
.
getenv
(
"LOCAL_WORLD_SIZE"
,
str
(
parallel_config
.
world_size
))
os
.
environ
[
'CCL_ZE_IPC_EXCHANGE'
]
=
ENV_CCL_ZE_IPC_EXCHANGE
os
.
environ
[
"LOCAL_WORLD_SIZE"
]
=
ENV_LOCAL_WORLD_SIZE
os
.
environ
[
"LOCAL_RANK"
]
=
str
(
self
.
local_rank
)
init_distributed_environment
(
world_size
=
parallel_config
.
world_size
,
rank
=
rank
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment