Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
647214f3
Unverified
Commit
647214f3
authored
Oct 21, 2025
by
Nick Hill
Committed by
GitHub
Oct 21, 2025
Browse files
[V0 Deprecation] Remove V0 executors (#27142)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
ddeec11b
Changes
31
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
387 additions
and
532 deletions
+387
-532
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+1
-1
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+1
-1
vllm/v1/engine/utils.py
vllm/v1/engine/utils.py
+1
-1
vllm/v1/executor/__init__.py
vllm/v1/executor/__init__.py
+6
-0
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+229
-40
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+2
-1
vllm/v1/executor/ray_distributed_executor.py
vllm/v1/executor/ray_distributed_executor.py
+4
-107
vllm/v1/executor/ray_executor.py
vllm/v1/executor/ray_executor.py
+104
-280
vllm/v1/executor/ray_utils.py
vllm/v1/executor/ray_utils.py
+28
-50
vllm/v1/executor/uniproc_executor.py
vllm/v1/executor/uniproc_executor.py
+11
-29
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+0
-22
No files found.
vllm/v1/engine/core_client.py
View file @
647214f3
...
...
@@ -46,7 +46,7 @@ from vllm.v1.engine.utils import (
CoreEngineProcManager
,
launch_core_engines
,
)
from
vllm.v1.executor
.abstract
import
Executor
from
vllm.v1.executor
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/engine/llm_engine.py
View file @
647214f3
...
...
@@ -32,7 +32,7 @@ from vllm.v1.engine.core_client import EngineCoreClient
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.executor
.abstract
import
Executor
from
vllm.v1.executor
import
Executor
from
vllm.v1.metrics.loggers
import
StatLoggerFactory
,
StatLoggerManager
from
vllm.v1.metrics.reader
import
Metric
,
get_metrics_snapshot
from
vllm.v1.metrics.stats
import
IterationStats
...
...
vllm/v1/engine/utils.py
View file @
647214f3
...
...
@@ -23,7 +23,7 @@ from vllm.ray.ray_env import get_env_vars_to_copy
from
vllm.utils
import
get_mp_context
from
vllm.utils.network_utils
import
get_open_zmq_ipc_path
,
zmq_socket_ctx
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.executor
.abstract
import
Executor
from
vllm.v1.executor
import
Executor
from
vllm.v1.utils
import
get_engine_client_zmq_addr
,
shutdown
if
TYPE_CHECKING
:
...
...
vllm/v1/executor/__init__.py
View file @
647214f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
.abstract
import
Executor
from
.uniproc_executor
import
UniProcExecutor
__all__
=
[
"Executor"
,
"UniProcExecutor"
]
vllm/v1/executor/abstract.py
View file @
647214f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
from
concurrent.futures
import
Future
from
typing
import
Any
import
torch
import
torch.distributed
as
dist
from
functools
import
cached_property
from
typing
import
Literal
,
TypeVar
,
overload
from
vllm.config
import
VllmConfig
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.uniproc_executor
import
(
# noqa
ExecutorWithExternalLauncher
as
ExecutorWithExternalLauncherV0
,
)
from
vllm.executor.uniproc_executor
import
UniProcExecutor
as
UniProcExecutorV0
# noqa
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.tasks
import
SupportedTask
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.engine
import
ReconfigureDistributedRequest
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
)
FailureCallback
=
Callable
[[],
None
]
class
Executor
(
ExecutorBase
):
class
Executor
(
ABC
):
"""Abstract base class for vLLM executors."
An executor is responsible for executing the model on one device,
or it can be a distributed executor that can execute the model on multiple devices.
"""
Abstract class for v1 executors, mainly define some methods for v1.
For methods shared by v0 and v1, define them in ExecutorBase"""
uses_ray
:
bool
=
False
# whether the executor uses Ray for orchestration.
supports_pp
:
bool
=
False
# whether the executor supports PP
@
staticmethod
def
get_class
(
vllm_config
:
VllmConfig
)
->
type
[
"Executor"
]:
...
...
@@ -34,16 +43,14 @@ class Executor(ExecutorBase):
distributed_executor_backend
=
parallel_config
.
distributed_executor_backend
# distributed_executor_backend must be set in VllmConfig.__post_init__
if
isinstance
(
distributed_executor_backend
,
type
):
if
not
issubclass
(
distributed_executor_backend
,
Executor
Base
):
if
not
issubclass
(
distributed_executor_backend
,
Executor
):
raise
TypeError
(
"distributed_executor_backend must be a subclass of "
f
"Executor
Base
. Got
{
distributed_executor_backend
}
."
f
"Executor. Got
{
distributed_executor_backend
}
."
)
executor_class
=
distributed_executor_backend
elif
distributed_executor_backend
==
"ray"
:
from
vllm.v1.executor.ray_distributed_executor
import
(
# noqa
RayDistributedExecutor
,
)
from
vllm.v1.executor.ray_executor
import
RayDistributedExecutor
executor_class
=
RayDistributedExecutor
elif
distributed_executor_backend
==
"mp"
:
...
...
@@ -51,6 +58,8 @@ class Executor(ExecutorBase):
executor_class
=
MultiprocExecutor
elif
distributed_executor_backend
==
"uni"
:
from
vllm.v1.executor.uniproc_executor
import
UniProcExecutor
executor_class
=
UniProcExecutor
elif
distributed_executor_backend
==
"external_launcher"
:
# TODO: make v1 scheduling deterministic
...
...
@@ -58,10 +67,10 @@ class Executor(ExecutorBase):
executor_class
=
ExecutorWithExternalLauncher
elif
isinstance
(
distributed_executor_backend
,
str
):
executor_class
=
resolve_obj_by_qualname
(
distributed_executor_backend
)
if
not
issubclass
(
executor_class
,
Executor
Base
):
if
not
issubclass
(
executor_class
,
Executor
):
raise
TypeError
(
"distributed_executor_backend must be a subclass of "
f
"Executor
Base
. Got
{
executor_class
}
."
f
"Executor. Got
{
executor_class
}
."
)
else
:
raise
ValueError
(
...
...
@@ -69,6 +78,29 @@ class Executor(ExecutorBase):
)
return
executor_class
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
)
->
None
:
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
load_config
=
vllm_config
.
load_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
device_config
=
vllm_config
.
device_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
_init_executor
()
self
.
is_sleeping
=
False
self
.
sleeping_tags
:
set
[
str
]
=
set
()
self
.
kv_output_aggregator
:
KVOutputAggregator
|
None
=
None
@
abstractmethod
def
_init_executor
(
self
)
->
None
:
raise
NotImplementedError
def
initialize_from_config
(
self
,
kv_cache_configs
:
list
[
KVCacheConfig
])
->
None
:
"""
Initialize the KV caches and begin the model execution loop of the
...
...
@@ -77,7 +109,7 @@ class Executor(ExecutorBase):
self
.
collective_rpc
(
"initialize_from_config"
,
args
=
(
kv_cache_configs
,))
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
def
register_failure_callback
(
self
,
callback
:
FailureCallback
):
def
register_failure_callback
(
self
,
callback
:
FailureCallback
):
# noqa: B027
"""
Register a function to be called if the executor enters a permanent
failed state.
...
...
@@ -90,22 +122,78 @@ class Executor(ExecutorBase):
def
get_kv_cache_specs
(
self
)
->
list
[
dict
[
str
,
KVCacheSpec
]]:
return
self
.
collective_rpc
(
"get_kv_cache_spec"
)
@
overload
def
collective_rpc
(
self
,
method
:
str
|
Callable
[[
WorkerBase
],
_R
],
timeout
:
float
|
None
=
None
,
args
:
tuple
=
(),
kwargs
:
dict
|
None
=
None
,
non_block
:
Literal
[
False
]
=
False
,
)
->
list
[
_R
]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
non_block: If `True`, returns a list of Futures instead of waiting
for the results.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
pass
@
overload
def
collective_rpc
(
self
,
method
:
str
|
Callable
,
method
:
str
|
Callable
[[
WorkerBase
],
_R
]
,
timeout
:
float
|
None
=
None
,
args
:
tuple
=
(),
kwargs
:
dict
|
None
=
None
,
non_block
:
bool
=
False
,
)
->
list
[
Any
]:
non_block
:
Literal
[
True
]
=
True
,
)
->
list
[
Future
[
_R
]]:
pass
@
abstractmethod
def
collective_rpc
(
self
,
method
,
timeout
=
None
,
args
=
(),
kwargs
=
None
,
non_block
:
bool
=
False
):
raise
NotImplementedError
@
overload
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
,
non_block
:
Literal
[
False
]
=
False
,
)
->
ModelRunnerOutput
:
pass
@
overload
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
Literal
[
True
]
=
True
,
)
->
Future
[
ModelRunnerOutput
]:
pass
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
output
=
self
.
collective_rpc
(
output
=
self
.
collective_rpc
(
# type: ignore[call-overload]
"execute_model"
,
args
=
(
scheduler_output
,),
non_block
=
non_block
)
return
output
[
0
]
...
...
@@ -114,7 +202,7 @@ class Executor(ExecutorBase):
self
.
collective_rpc
(
"execute_dummy_batch"
)
def
take_draft_token_ids
(
self
)
->
DraftTokenIds
|
None
:
output
=
self
.
collective_rpc
(
"take_draft_token_ids"
)
output
:
list
[
DraftTokenIds
]
=
self
.
collective_rpc
(
"take_draft_token_ids"
)
return
output
[
0
]
@
property
...
...
@@ -124,19 +212,120 @@ class Executor(ExecutorBase):
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
collective_rpc
(
"profile"
,
args
=
(
is_start
,))
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
str
|
None
=
None
,
max_size
:
int
|
None
=
None
,
)
->
None
:
self
.
collective_rpc
(
"save_sharded_state"
,
kwargs
=
dict
(
path
=
path
,
pattern
=
pattern
,
max_size
=
max_size
),
)
@
abstractmethod
def
check_health
(
self
)
->
None
:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise
NotImplementedError
class
UniProcExecutor
(
UniProcExecutorV0
,
Executor
):
pass
def
shutdown
(
self
)
->
None
:
"""Shutdown the executor."""
self
.
collective_rpc
(
"shutdown"
)
def
init_kv_output_aggregator
(
self
,
finished_count
:
int
|
None
)
->
None
:
"""Init KVOutputAggregator"""
self
.
kv_output_aggregator
=
KVOutputAggregator
(
finished_count
or
self
.
parallel_config
.
world_size
)
class
ExecutorWithExternalLauncher
(
ExecutorWithExternalLauncherV0
,
Executor
):
def
determine_available_memory
(
self
)
->
list
[
int
]:
# in bytes
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
memory
=
super
().
determine_available_memory
()
from
vllm.distributed.parallel_state
import
get_world_group
cpu_group
=
get_world_group
().
cpu_group
memory_tensor
=
torch
.
tensor
([
memory
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
dist
.
all_reduce
(
memory_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
return
[
memory_tensor
.
item
()]
@
cached_property
# Avoid unnecessary RPC calls
def
supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
output
:
list
[
tuple
[
SupportedTask
,
...]]
output
=
self
.
collective_rpc
(
"get_supported_tasks"
)
return
output
[
0
]
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
assert
lora_request
.
lora_int_id
>
0
,
"lora_id must be greater than 0."
return
all
(
self
.
collective_rpc
(
"add_lora"
,
args
=
(
lora_request
,)))
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
assert
lora_id
>
0
,
"lora_id must be greater than 0."
return
all
(
self
.
collective_rpc
(
"remove_lora"
,
args
=
(
lora_id
,)))
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
assert
lora_id
>
0
,
"lora_id must be greater than 0."
return
all
(
self
.
collective_rpc
(
"pin_lora"
,
args
=
(
lora_id
,)))
def
list_loras
(
self
)
->
set
[
int
]:
sets
:
list
[
set
[
int
]]
=
self
.
collective_rpc
(
"list_loras"
)
for
s
in
sets
:
assert
s
==
sets
[
0
],
"All workers should have the same LORAs."
return
sets
[
0
]
def
reset_mm_cache
(
self
)
->
None
:
"""Reset the multi-modal cache in each worker."""
self
.
collective_rpc
(
"reset_mm_cache"
)
def
start_profile
(
self
)
->
None
:
self
.
collective_rpc
(
"start_profile"
)
def
stop_profile
(
self
)
->
None
:
self
.
collective_rpc
(
"stop_profile"
)
def
sleep
(
self
,
level
:
int
=
1
):
if
self
.
is_sleeping
:
logger
.
warning
(
"Executor is already sleeping."
)
return
time_before_sleep
=
time
.
perf_counter
()
self
.
collective_rpc
(
"sleep"
,
kwargs
=
dict
(
level
=
level
))
time_after_sleep
=
time
.
perf_counter
()
self
.
sleeping_tags
=
{
"weights"
,
"kv_cache"
}
self
.
is_sleeping
=
True
logger
.
info
(
"It took %.6f seconds to fall asleep."
,
time_after_sleep
-
time_before_sleep
)
def
wake_up
(
self
,
tags
:
list
[
str
]
|
None
=
None
):
if
not
self
.
is_sleeping
:
logger
.
warning
(
"Executor is not sleeping."
)
return
if
tags
:
for
tag
in
tags
:
if
tag
not
in
self
.
sleeping_tags
:
logger
.
warning
(
"Tag %s is not in sleeping tags %s"
,
tag
,
self
.
sleeping_tags
)
return
time_before_wakeup
=
time
.
perf_counter
()
self
.
collective_rpc
(
"wake_up"
,
kwargs
=
dict
(
tags
=
tags
))
time_after_wakeup
=
time
.
perf_counter
()
logger
.
info
(
"It took %.6f seconds to wake up tags %s."
,
time_after_wakeup
-
time_before_wakeup
,
tags
if
tags
is
not
None
else
self
.
sleeping_tags
,
)
if
tags
:
for
tag
in
tags
:
self
.
sleeping_tags
.
remove
(
tag
)
else
:
self
.
sleeping_tags
.
clear
()
if
not
self
.
sleeping_tags
:
self
.
is_sleeping
=
False
def
reinitialize_distributed
(
self
,
reconfig_request
:
ReconfigureDistributedRequest
)
->
None
:
raise
NotImplementedError
from
vllm.v1.executor.uniproc_executor
import
(
# noqa: E402
ExecutorWithExternalLauncher
as
_ExecutorWithExternalLauncher
,
)
from
vllm.v1.executor.uniproc_executor
import
(
# noqa: E402
UniProcExecutor
as
_UniProcExecutor
,
)
# For backwards compatibility.
UniProcExecutor
=
_UniProcExecutor
ExecutorWithExternalLauncher
=
_ExecutorWithExternalLauncher
vllm/v1/executor/multiproc_executor.py
View file @
647214f3
...
...
@@ -179,7 +179,7 @@ class MultiprocExecutor(Executor):
else
:
self
.
failure_callback
=
callback
def
execute_model
(
def
execute_model
(
# type: ignore[override]
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
,
...
...
@@ -204,6 +204,7 @@ class MultiprocExecutor(Executor):
)
# aggregate all workers output to a single output
assert
self
.
kv_output_aggregator
is
not
None
if
non_block
:
return
self
.
kv_output_aggregator
.
async_aggregate
(
outputs
,
self
.
output_rank
)
return
self
.
kv_output_aggregator
.
aggregate
(
outputs
,
self
.
output_rank
)
...
...
vllm/v1/executor/ray_distributed_executor.py
View file @
647214f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
concurrent.futures
import
Future
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.executor.ray_distributed_executor
import
(
# noqa
RayDistributedExecutor
as
RayDistributedExecutorV0
,
from
vllm.v1.executor.ray_executor
import
(
RayDistributedExecutor
as
_RayDistributedExecutor
,
)
from
vllm.logger
import
init_logger
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
logger
=
init_logger
(
__name__
)
class
FutureWrapper
(
Future
):
"""A wrapper around Ray output reference to meet the interface
of .execute_model(): The top level (core busy loop) expects .result() api
to block and return a single output.
If aggregator is provided, the outputs from all workers are aggregated upon
the result() call. If not only the first worker's output is returned.
"""
def
__init__
(
self
,
refs
,
aggregator
:
KVOutputAggregator
|
None
=
None
):
super
().
__init__
()
self
.
refs
=
refs
self
.
aggregator
=
aggregator
def
result
(
self
,
timeout
=
None
):
if
timeout
is
not
None
:
raise
NotImplementedError
(
"timeout is not supported"
)
if
self
.
aggregator
is
None
:
return
self
.
refs
[
0
].
get
()
outputs
=
[
ref
.
get
()
for
ref
in
self
.
refs
]
return
self
.
aggregator
.
aggregate
(
outputs
,
output_rank
=
0
)
class
RayDistributedExecutor
(
RayDistributedExecutorV0
,
Executor
):
"""Ray distributed executor using Ray Compiled Graphs."""
supports_pp
:
bool
=
True
def
_init_executor
(
self
)
->
None
:
super
().
_init_executor
()
# KV connector setup
self
.
has_connector
=
self
.
vllm_config
.
kv_transfer_config
is
not
None
@
property
def
max_concurrent_batches
(
self
)
->
int
:
"""Ray distributed executor supports pipeline parallelism,
meaning that it allows PP size batches to be executed concurrently.
"""
if
self
.
scheduler_config
.
async_scheduling
:
return
2
return
self
.
parallel_config
.
pipeline_parallel_size
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
,
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
"""Execute the model on the Ray workers.
Args:
scheduler_output: The scheduler output to execute.
non_block: If True, the method will return a Future.
Returns:
The model runner output.
"""
# Build the compiled DAG for the first time.
if
self
.
forward_dag
is
None
:
# type: ignore
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
refs
=
self
.
forward_dag
.
execute
(
scheduler_output
)
# type: ignore
if
not
self
.
has_connector
:
# Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available.
if
not
non_block
:
return
refs
[
0
].
get
()
# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return
FutureWrapper
(
refs
)
# Get output from all workers when connector is present
if
not
non_block
:
# Block and get results from all workers
outputs
=
[
ref
.
get
()
for
ref
in
refs
]
return
self
.
kv_output_aggregator
.
aggregate
(
outputs
)
# Return a future that will aggregate outputs from all workers
return
FutureWrapper
(
refs
,
self
.
kv_output_aggregator
)
def
reinitialize_distributed
(
self
,
reconfig_request
:
ReconfigureDistributedRequest
)
->
None
:
self
.
_run_workers
(
"reinitialize_distributed"
,
reconfig_request
)
if
(
reconfig_request
.
new_data_parallel_rank
==
ReconfigureRankType
.
SHUTDOWN_CURRENT_RANK
):
self
.
shutdown
()
# For backwards compatibility.
RayDistributedExecutor
=
_RayDistributedExecutor
vllm/executor/ray_
distributed_
executor.py
→
vllm/
v1/
executor/ray_executor.py
View file @
647214f3
This diff is collapsed.
Click to expand it.
vllm/executor/ray_utils.py
→
vllm/
v1/
executor/ray_utils.py
View file @
647214f3
...
...
@@ -4,17 +4,16 @@
import
os
import
time
from
collections
import
defaultdict
from
concurrent.futures
import
Future
from
typing
import
TYPE_CHECKING
,
Union
import
msgspec
import
vllm.platforms
from
vllm.config
import
ParallelConfig
from
vllm.distributed
import
get_pp_group
from
vllm.
ex
ec
u
tor.
msgspec_
utils
import
decode_hook
,
encode_hook
from
vllm.
distributed.kv_transfer.kv_conn
ector.utils
import
KVOutputAggregator
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.network_utils
import
get_ip
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -51,11 +50,6 @@ try:
# that thread.
self
.
compiled_dag_cuda_device_set
=
False
self
.
input_decoder
=
msgspec
.
msgpack
.
Decoder
(
ExecuteModelRequest
,
dec_hook
=
decode_hook
)
self
.
output_encoder
=
msgspec
.
msgpack
.
Encoder
(
enc_hook
=
encode_hook
)
def
get_node_ip
(
self
)
->
str
:
return
get_ip
()
...
...
@@ -70,47 +64,6 @@ try:
gpu_ids
=
ray
.
get_runtime_context
().
get_accelerator_ids
()[
device_key
]
return
node_id
,
gpu_ids
def
execute_model_spmd
(
self
,
req_or_tuple
:
bytes
|
tuple
[
bytes
,
IntermediateTensors
|
None
],
)
->
bytes
:
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.
Args:
req_or_tuple: A request or a tuple containing the
request and intermediate tensors. Intermediate tensors are
None unless if it is provided because it is > 0 pipeline
stage. The request is serialized by msgspec.
"""
if
isinstance
(
req_or_tuple
,
bytes
):
serialized_req
,
intermediate_tensors
=
req_or_tuple
,
None
else
:
serialized_req
,
intermediate_tensors
=
req_or_tuple
execute_model_req
=
self
.
input_decoder
.
decode
(
serialized_req
)
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
# TODO(swang): This is needed right now because Ray Compiled Graph
# executes on a background thread, so we need to reset torch's
# current device.
if
not
self
.
compiled_dag_cuda_device_set
:
assert
self
.
worker
.
device
is
not
None
current_platform
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
output
=
self
.
worker
.
_execute_model_spmd
(
# type: ignore[attr-defined]
execute_model_req
,
intermediate_tensors
)
# Pipeline model request and output to the next pipeline stage.
if
isinstance
(
output
,
IntermediateTensors
):
output
=
serialized_req
,
output
else
:
output
=
self
.
output_encoder
.
encode
(
output
)
return
output
def
setup_device_if_necessary
(
self
):
# TODO(swang): This is needed right now because Ray CG executes
# on a background thread, so we need to reset torch's current
...
...
@@ -174,6 +127,31 @@ except ImportError as e:
RayWorkerWrapper
=
None
# type: ignore
class
FutureWrapper
(
Future
):
"""A wrapper around Ray output reference to meet the interface
of .execute_model(): The top level (core busy loop) expects .result() api
to block and return a single output.
If aggregator is provided, the outputs from all workers are aggregated upon
the result() call. If not only the first worker's output is returned.
"""
def
__init__
(
self
,
refs
,
aggregator
:
KVOutputAggregator
|
None
=
None
):
super
().
__init__
()
self
.
refs
=
refs
self
.
aggregator
=
aggregator
def
result
(
self
,
timeout
=
None
):
if
timeout
is
not
None
:
raise
NotImplementedError
(
"timeout is not supported"
)
if
self
.
aggregator
is
None
:
return
self
.
refs
[
0
].
get
()
outputs
=
[
ref
.
get
()
for
ref
in
self
.
refs
]
return
self
.
aggregator
.
aggregate
(
outputs
,
output_rank
=
0
)
def
ray_is_available
()
->
bool
:
"""Returns True if Ray is available."""
return
ray
is
not
None
...
...
vllm/executor/uniproc_executor.py
→
vllm/
v1/
executor/uniproc_executor.py
View file @
647214f3
...
...
@@ -11,20 +11,18 @@ import torch
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.utils
import
run_method
from
vllm.utils.network_utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
class
UniProcExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
class
UniProcExecutor
(
Executor
):
def
_init_executor
(
self
)
->
None
:
"""Initialize the worker and load the model."""
self
.
driver_worker
=
WorkerWrapperBase
(
vllm_config
=
self
.
vllm_config
,
rpc_rank
=
0
)
...
...
@@ -44,9 +42,9 @@ class UniProcExecutor(ExecutorBase):
max_workers
=
1
,
thread_name_prefix
=
"WorkerAsyncOutput"
)
self
.
collective_rpc
(
"
init_worker
"
,
args
=
(
[
kwargs
]
,)
)
self
.
collective_rpc
(
"
init_device
"
)
self
.
collective_rpc
(
"
load_model
"
)
self
.
driver_worker
.
init_worker
(
all_kw
args
=
[
kwargs
])
self
.
driver_worker
.
init_device
(
)
self
.
driver_worker
.
load_model
(
)
def
_distributed_args
(
self
)
->
tuple
[
str
,
int
,
int
]:
"""Return (distributed_init_method, rank, local_rank)."""
...
...
@@ -101,16 +99,12 @@ class UniProcExecutor(ExecutorBase):
==
ReconfigureRankType
.
SHUTDOWN_CURRENT_RANK
):
self
.
shutdown
()
return
def
shutdown
(
self
)
->
None
:
if
worker
:
=
self
.
driver_worker
:
worker
.
shutdown
()
UniProcExecutorAsync
=
UniProcExecutor
class
ExecutorWithExternalLauncher
(
UniProcExecutor
):
"""An executor that uses external launchers to launch engines,
specially designed for torchrun-compatible launchers, for
...
...
@@ -128,8 +122,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
and they don't need to synchronize the states with each other.
"""
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
"""Initialize the worker and load the model."""
if
envs
.
VLLM_USE_V1
:
...
...
@@ -152,22 +144,12 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
local_rank
=
int
(
os
.
environ
[
"LOCAL_RANK"
])
return
distributed_init_method
,
rank
,
local_rank
def
determine_num_available_blocks
(
self
)
->
tuple
[
int
,
int
]:
"""
Determine the number of available KV blocks.
Add an additional all_reduce to get the min across all ranks.
Note that even if we have the same `gpu_memory_utilization` and
`swap_space`, the available memory in every rank might still
differ because NCCL can take different amounts of memory in
different ranks. Therefore, it is necessary to test if all ranks
agree on the same KV cache configuration.
"""
a
,
b
=
super
().
determine_num_available_blocks
()
def
determine_available_memory
(
self
)
->
list
[
int
]:
# in bytes
# we need to get the min across all ranks.
memory
=
super
().
determine_available_memory
()
from
vllm.distributed.parallel_state
import
get_world_group
cpu_group
=
get_world_group
().
cpu_group
a_tensor
=
torch
.
tensor
([
a
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
b_tensor
=
torch
.
tensor
([
b
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
dist
.
all_reduce
(
a_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
dist
.
all_reduce
(
b_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
return
a_tensor
.
item
(),
b_tensor
.
item
()
memory_tensor
=
torch
.
tensor
([
memory
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
dist
.
all_reduce
(
memory_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
return
[
memory_tensor
.
item
()]
vllm/v1/worker/worker_base.py
View file @
647214f3
...
...
@@ -128,28 +128,6 @@ class WorkerBase:
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
)
->
ModelRunnerOutput
:
raise
NotImplementedError
def
start_worker_execution_loop
(
self
)
->
None
:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
raise
NotImplementedError
(
"Dead V0 code"
)
def
determine_num_available_blocks
(
self
)
->
tuple
[
int
,
int
]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise
NotImplementedError
def
get_cache_block_size_bytes
(
self
)
->
int
:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment