Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
66a9e713
Unverified
Commit
66a9e713
authored
Aug 20, 2024
by
Antoni Baum
Committed by
GitHub
Aug 21, 2024
Browse files
[Core] Pipe `worker_class_fn` argument in Executor (#7707)
parent
9e51b6a6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
14 deletions
+26
-14
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+17
-9
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+3
-2
vllm/executor/xpu_executor.py
vllm/executor/xpu_executor.py
+6
-3
No files found.
vllm/executor/gpu_executor.py
View file @
66a9e713
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
...
...
@@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerBase
,
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
def
create_worker
(
worker_module_name
,
worker_class_name
,
**
kwargs
):
def
create_worker
(
worker_module_name
:
str
,
worker_class_name
:
str
,
worker_class_fn
:
Optional
[
Callable
[[],
Type
[
WorkerBase
]]],
**
kwargs
):
wrapper
=
WorkerWrapperBase
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
worker_class_fn
=
worker_class_fn
,
)
wrapper
.
init_worker
(
**
kwargs
)
return
wrapper
.
worker
...
...
@@ -62,7 +65,9 @@ class GPUExecutor(ExecutorBase):
observability_config
=
self
.
observability_config
,
)
def
_get_worker_module_and_class
(
self
)
->
Tuple
[
str
,
str
]:
def
_get_worker_module_and_class
(
self
)
->
Tuple
[
str
,
str
,
Optional
[
Callable
[[],
Type
[
WorkerBase
]]]]:
worker_class_fn
=
None
if
self
.
scheduler_config
.
is_multi_step
:
worker_module_name
=
"vllm.worker.multi_step_worker"
worker_class_name
=
"MultiStepWorker"
...
...
@@ -72,7 +77,7 @@ class GPUExecutor(ExecutorBase):
else
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
return
(
worker_module_name
,
worker_class_name
)
return
(
worker_module_name
,
worker_class_name
,
worker_class_fn
)
def
_get_create_worker_kwargs
(
self
,
...
...
@@ -82,10 +87,13 @@ class GPUExecutor(ExecutorBase):
worker_kwargs
=
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
)
(
worker_module_name
,
worker_class_name
)
=
self
.
_get_worker_module_and_class
()
worker_kwargs
.
update
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
)
(
worker_module_name
,
worker_class_name
,
worker_class_fn
)
=
self
.
_get_worker_module_and_class
()
worker_kwargs
.
update
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
worker_class_fn
=
worker_class_fn
,
)
return
worker_kwargs
...
...
vllm/executor/ray_gpu_executor.py
View file @
66a9e713
...
...
@@ -91,12 +91,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
return
ray_remote_kwargs
def
_get_worker_wrapper_args
(
self
)
->
Dict
[
str
,
Any
]:
(
worker_module_name
,
worker_class_n
ame
)
=
self
.
_get_worker_module_and_class
()
(
worker_module_name
,
worker_class_name
,
worker_class_
f
n
)
=
self
.
_get_worker_module_and_class
()
return
dict
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
worker_class_fn
=
worker_class_fn
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
...
...
vllm/executor/xpu_executor.py
View file @
66a9e713
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
torch
...
...
@@ -11,6 +11,7 @@ from vllm.executor.gpu_executor import GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.utils
import
make_async
from
vllm.worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
...
...
@@ -52,14 +53,16 @@ class XPUExecutor(GPUExecutor):
# Instantiate the worker and load the model to GPU.
self
.
_init_executor
()
def
_get_worker_module_and_class
(
self
)
->
Tuple
[
str
,
str
]:
def
_get_worker_module_and_class
(
self
)
->
Tuple
[
str
,
str
,
Optional
[
Callable
[[],
Type
[
WorkerBase
]]]]:
worker_class_fn
=
None
if
self
.
speculative_config
is
not
None
:
raise
NotImplementedError
(
"XPU does not support speculative decoding"
)
else
:
worker_module_name
=
"vllm.worker.xpu_worker"
worker_class_name
=
"XPUWorker"
return
(
worker_module_name
,
worker_class_name
)
return
(
worker_module_name
,
worker_class_name
,
worker_class_fn
)
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment