Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6a113d9a
Unverified
Commit
6a113d9a
authored
Sep 29, 2025
by
Aaron Pham
Committed by
GitHub
Sep 29, 2025
Browse files
[V0 Deprecation] Remove `vllm.worker` and update according imports (#25901)
parent
2e4fe48c
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
276 additions
and
327 deletions
+276
-327
tests/model_executor/model_loader/tensorizer_loader/conftest.py
...model_executor/model_loader/tensorizer_loader/conftest.py
+1
-1
tools/pre_commit/check_pickle_imports.py
tools/pre_commit/check_pickle_imports.py
+0
-1
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+5
-5
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+1
-1
vllm/executor/uniproc_executor.py
vllm/executor/uniproc_executor.py
+5
-5
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+1
-11
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-11
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+2
-2
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+260
-11
vllm/worker/__init__.py
vllm/worker/__init__.py
+0
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+0
-279
No files found.
tests/model_executor/model_loader/tensorizer_loader/conftest.py
View file @
6a113d9a
...
...
@@ -10,7 +10,7 @@ from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.executor.abstract
import
UniProcExecutor
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.
v1.
worker.worker_base
import
WorkerWrapperBase
MODEL_REF
=
"facebook/opt-125m"
...
...
tools/pre_commit/check_pickle_imports.py
View file @
6a113d9a
...
...
@@ -36,7 +36,6 @@ ALLOWED_FILES = {
'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py'
,
'benchmarks/cutlass_benchmarks/sparse_benchmarks.py'
,
# cloudpickle
'vllm/worker/worker_base.py'
,
'vllm/executor/mp_distributed_executor.py'
,
'vllm/executor/ray_distributed_executor.py'
,
'vllm/entrypoints/llm.py'
,
...
...
vllm/executor/executor_base.py
View file @
6a113d9a
...
...
@@ -19,7 +19,7 @@ from vllm.sequence import ExecuteModelRequest
from
vllm.tasks
import
SupportedTask
from
vllm.utils
import
make_async
from
vllm.v1.outputs
import
PoolerOutput
,
SamplerOutput
from
vllm.worker.worker_base
import
WorkerBase
from
vllm.
v1.
worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
...
...
vllm/executor/ray_utils.py
View file @
6a113d9a
...
...
@@ -16,7 +16,7 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.utils
import
get_ip
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.
v1.
worker.worker_base
import
WorkerWrapperBase
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
vllm/executor/uniproc_executor.py
View file @
6a113d9a
...
...
@@ -19,7 +19,7 @@ from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.executor.utils
import
get_and_update_mm_cache
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.
v1.
worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
...
...
vllm/platforms/cuda.py
View file @
6a113d9a
...
...
@@ -110,17 +110,7 @@ class CudaPlatformBase(Platform):
model_config
=
vllm_config
.
model_config
if
parallel_config
.
worker_cls
==
"auto"
:
if
vllm_config
.
speculative_config
:
if
not
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
"Speculative decoding is not supported on vLLM V0."
)
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
else
:
if
envs
.
VLLM_USE_V1
:
parallel_config
.
worker_cls
=
\
"vllm.v1.worker.gpu_worker.Worker"
else
:
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
cache_config
=
vllm_config
.
cache_config
if
cache_config
and
cache_config
.
block_size
is
None
:
...
...
vllm/platforms/rocm.py
View file @
6a113d9a
...
...
@@ -327,17 +327,7 @@ class RocmPlatform(Platform):
cache_config
.
block_size
=
16
if
parallel_config
.
worker_cls
==
"auto"
:
if
vllm_config
.
speculative_config
:
if
not
use_v1
:
raise
NotImplementedError
(
"Speculative decoding is not supported on vLLM V0."
)
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
else
:
if
use_v1
:
parallel_config
.
worker_cls
=
\
"vllm.v1.worker.gpu_worker.Worker"
else
:
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
# Aiter rms norm perform best when CUDA Graph capture is enabled.
if
(
use_v1
and
use_aiter_rms_norm
and
not
is_eager_execution
and
"-rms_norm"
not
in
compilation_config
.
custom_ops
):
...
...
vllm/v1/executor/multiproc_executor.py
View file @
6a113d9a
...
...
@@ -41,7 +41,7 @@ from vllm.v1.executor.abstract import Executor, FailureCallback
from
vllm.v1.executor.utils
import
get_and_update_mm_cache
from
vllm.v1.outputs
import
(
AsyncModelRunnerOutput
,
DraftTokenIds
,
ModelRunnerOutput
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.
v1.
worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/worker/worker_base.py
View file @
6a113d9a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
__future__
import
annotations
import
os
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
resolve_obj_by_qualname
,
run_method
,
update_environment_variables
,
warn_for_unimplemented_methods
)
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
from
vllm.
worker.worker_base
import
WorkerBase
as
WorkerBaseV0
from
vllm.
v1.outputs
import
SamplerOutput
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Generic return type of the callable handed to `WorkerBase.apply_model`.
_R = TypeVar("_R")
class
WorkerBase
(
WorkerBaseV0
):
"""
Abstract class for v1 worker, mainly define some methods for v1.
For methods shared by v0 and v1, define them in v0 WorkerBase
@
warn_for_unimplemented_methods
class
WorkerBase
:
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
"""
def
__init__
(
...
...
@@ -27,7 +39,7 @@ class WorkerBase(WorkerBaseV0):
rank
:
int
,
distributed_init_method
:
str
,
is_driver_worker
:
bool
=
False
,
):
)
->
None
:
"""
Initialize common worker components.
...
...
@@ -39,8 +51,21 @@ class WorkerBase(WorkerBaseV0):
is_driver_worker: Whether this worker handles driver
responsibilities
"""
# Configuration storage
super
().
__init__
(
vllm_config
=
vllm_config
)
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
load_config
=
vllm_config
.
load_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
device_config
=
vllm_config
.
device_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
self
.
compilation_config
=
vllm_config
.
compilation_config
from
vllm.platforms
import
current_platform
self
.
current_platform
=
current_platform
self
.
parallel_config
.
rank
=
rank
self
.
local_rank
=
local_rank
...
...
@@ -63,3 +88,227 @@ class WorkerBase(WorkerBaseV0):
def check_health(self) -> None:
    """Basic health check (override for device-specific checks).

    The base implementation performs no checks and simply returns.
    """
    return
def init_device(self) -> None:
    """Initialize device state, such as loading the model or other on-device
    memory allocations.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks: int) -> None:
    """Initialize the KV cache with the given size in blocks.

    Args:
        num_gpu_blocks: Number of GPU blocks to size the cache with.
        num_cpu_blocks: Number of CPU blocks to size the cache with.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def get_model(self) -> nn.Module:
    """Return the model held by this worker.

    `apply_model` passes the returned module to its callback.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
    """Run `fn` against the model inside this worker.

    Args:
        fn: Callable that receives the module returned by `get_model()`.

    Returns:
        Whatever `fn` returns.
    """
    model = self.get_model()
    return fn(model)
def load_model(self) -> None:
    """Load model onto target device.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def execute_model(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[list[SamplerOutput]]:
    """Abstract hook for executing the model on an optional request.

    `start_worker_execution_loop` calls this with `execute_model_req=None`
    and treats a `None` return value as the signal to stop looping.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def start_worker_execution_loop(self) -> None:
    """Execute model loop in parallel worker.

    Runs under the platform's inference mode and keeps calling
    `execute_model(execute_model_req=None)` until it yields `None`, i.e.
    until a driver worker is executed with an empty output.
    See `stop_remote_worker_execution_loop` for more details.
    """
    with self.current_platform.inference_mode():
        # An empty (None) output is the stop signal; anything else means
        # "keep going", and the output itself is not used here.
        while self.execute_model(execute_model_req=None) is not None:
            pass
    return None
def determine_num_available_blocks(self) -> tuple[int, int]:
    """Determine the number of available blocks for the GPU KV cache and
    swappable CPU KV cache.

    The implementation may run profiling or other heuristics to determine
    the size of caches.

    Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
    are blocks that are "active" on the device and can be appended to.
    num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
    appended to.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def get_cache_block_size_bytes(self) -> int:
    """Return the size of a single cache block, in bytes. Used in
    speculative decoding.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Abstract hook for adding the LoRA described by `lora_request`.

    NOTE(review): the meaning of the bool result is not visible here;
    presumably it reports success -- confirm against implementations.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def remove_lora(self, lora_id: int) -> bool:
    """Abstract hook for removing the LoRA identified by `lora_id`.

    NOTE(review): the meaning of the bool result is not visible here;
    presumably it reports success -- confirm against implementations.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def pin_lora(self, lora_id: int) -> bool:
    """Abstract hook for pinning the LoRA identified by `lora_id`.

    NOTE(review): the meaning of the bool result is not visible here;
    presumably it reports success -- confirm against implementations.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
def list_loras(self) -> set[int]:
    """Abstract hook for listing LoRAs as a set of integers.

    NOTE(review): the integers are presumably LoRA ids, matching the
    `lora_id` parameters of `remove_lora` / `pin_lora` -- confirm.

    Raises:
        NotImplementedError: Always; the base class provides no
            implementation.
    """
    raise NotImplementedError
@property
def vocab_size(self) -> int:
    """Vocabulary size, as reported by the model configuration."""
    model_config = self.model_config
    return model_config.get_vocab_size()
def shutdown(self) -> None:
    """Clean up resources held by the worker.

    The base implementation holds nothing to release and simply returns;
    `WorkerWrapperBase.shutdown` forwards to this method.
    """
    return
class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible
    for lazily initializing the worker and handling the worker's lifecycle.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Then we call `update_environment_variables`, and the
    real initialization happens in `init_worker`.

    Attribute lookups that are not satisfied by the wrapper itself are
    forwarded to the wrapped worker (see `__getattr__`).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rpc_rank: int = 0,
    ) -> None:
        """
        Initialize the worker wrapper with the given vllm_config and rpc_rank.
        Note: rpc_rank is the rank of the worker in the executor. In most cases,
        it is also the rank of the worker in the distributed group. However,
        when multiple executors work together, they can be different.
        e.g. in the case of SPMD-style offline inference with TP=2,
        users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.
        """
        self.rpc_rank = rpc_rank
        self.worker: Optional[WorkerBase] = None
        self.vllm_config: Optional[VllmConfig] = None
        # do not store this `vllm_config`, `init_worker` will set the final
        # one. TODO: investigate if we can remove this field in
        # `WorkerWrapperBase`, `init_cached_hf_modules` should be
        # unnecessary now.
        if vllm_config.model_config is not None:
            # it can be None in tests
            trust_remote_code = vllm_config.model_config.trust_remote_code
            if trust_remote_code:
                # note: lazy import to avoid importing torch before initializing
                from vllm.utils import init_cached_hf_modules
                init_cached_hf_modules()

    def shutdown(self) -> None:
        """Shut down the wrapped worker, if one has been created."""
        if self.worker is not None:
            self.worker.shutdown()

    def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
        """
        Adjust the rpc_rank based on the given mapping.
        It is only used during the initialization of the executor,
        to adjust the rpc_rank of workers after we create all workers.
        Ranks that do not appear in the mapping are left unchanged.
        """
        if self.rpc_rank in rank_mapping:
            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(
        self,
        envs_list: list[dict[str, str]],
    ) -> None:
        """Apply the environment variables destined for this worker.

        Args:
            envs_list: One mapping per worker; the entry at index
                `self.rpc_rank` is the one applied.
        """
        envs = envs_list[self.rpc_rank]
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.

        Args:
            all_kwargs: One kwargs dict per worker; the entry at index
                `self.rpc_rank` is used and must contain `vllm_config`.

        Raises:
            ValueError: If `parallel_config.worker_cls` is not a string.
        """
        kwargs = all_kwargs[self.rpc_rank]
        self.vllm_config = kwargs.get("vllm_config")
        assert self.vllm_config is not None, (
            "vllm_config is required to initialize the worker")
        enable_trace_function_call_for_thread(self.vllm_config)

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        if isinstance(self.vllm_config.parallel_config.worker_cls, str):
            worker_class = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_cls)
        else:
            raise ValueError(
                "passing worker_cls is no longer supported. Please keep the class in a separate module and pass the qualified name of the class as a string."  # noqa: E501
            )
        if self.vllm_config.parallel_config.worker_extension_cls:
            worker_extension_cls = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_extension_cls)
            extended_calls = []
            if worker_extension_cls not in worker_class.__bases__:
                # check any conflicts between worker and worker_extension_cls
                for attr in dir(worker_extension_cls):
                    if attr.startswith("__"):
                        continue
                    assert not hasattr(worker_class, attr), (
                        f"Worker class {worker_class} already has an attribute"
                        f" {attr}, which conflicts with the worker"
                        f" extension class {worker_extension_cls}.")
                    if callable(getattr(worker_extension_cls, attr)):
                        extended_calls.append(attr)
                # dynamically inherit the worker extension class
                worker_class.__bases__ = worker_class.__bases__ + (
                    worker_extension_cls, )
                logger.info(
                    "Injected %s into %s for extended collective_rpc calls %s",
                    worker_extension_cls, worker_class, extended_calls)
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)
            assert self.worker is not None

    def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
        """Forward this worker's KV cache config to the wrapped worker.

        Args:
            kv_cache_configs: One config per worker; the entry at index
                `self.rpc_rank` is passed on.
        """
        kv_cache_config = kv_cache_configs[self.rpc_rank]
        with set_current_vllm_config(self.vllm_config):
            self.worker.initialize_from_config(
                kv_cache_config)  # type: ignore

    def init_device(self):
        """Initialize the wrapped worker's device state."""
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during device initialization
            self.worker.init_device()  # type: ignore

    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        """Run `method` on this wrapper (or, via `__getattr__`, the worker).

        Any exception is logged and then re-raised unchanged.
        """
        try:
            # method resolution order:
            # if a method is defined in this class, it will be called directly.
            # otherwise, since we define `__getattr__` and redirect attribute
            # query to `self.worker`, the method will be called on the worker.
            return run_method(self, method, args, kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method!r}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # bare `raise` re-raises the active exception with its original
            # traceback intact
            raise

    def __getattr__(self, attr):
        """Forward unknown attribute lookups to the wrapped worker.

        Raises:
            AttributeError: If the attribute does not exist on the worker,
                or if `worker` itself has not been set on this instance.
        """
        # `__getattr__` only runs after normal lookup has failed. If `worker`
        # itself is missing (instance created via `__new__`, e.g. while being
        # copied or unpickled), reading `self.worker` below would re-enter
        # this method forever; fail with a regular AttributeError instead.
        if attr == "worker":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'worker'")
        return getattr(self.worker, attr)
vllm/worker/__init__.py
deleted
100644 → 0
View file @
2e4fe48c
vllm/worker/worker_base.py
deleted
100644 → 0
View file @
2e4fe48c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
TypeVar
,
Union
)
import
cloudpickle
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
resolve_obj_by_qualname
,
run_method
,
update_environment_variables
,
warn_for_unimplemented_methods
)
from
vllm.v1.outputs
import
SamplerOutput
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Generic return type of the callable handed to `WorkerBase.apply_model`.
_R = TypeVar("_R")
@warn_for_unimplemented_methods
class WorkerBase:
    """Hardware-agnostic worker interface.

    Lets vLLM keep per-hardware implementations cleanly separated, and also
    abstracts control plane communication, e.g., communicating request
    metadata to other workers.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Keep the full config plus a shortcut to each of its sub-configs."""
        cfg = vllm_config
        self.vllm_config = cfg
        self.model_config = cfg.model_config
        self.cache_config = cfg.cache_config
        self.lora_config = cfg.lora_config
        self.load_config = cfg.load_config
        self.parallel_config = cfg.parallel_config
        self.scheduler_config = cfg.scheduler_config
        self.device_config = cfg.device_config
        self.speculative_config = cfg.speculative_config
        self.observability_config = cfg.observability_config
        self.kv_transfer_config = cfg.kv_transfer_config
        self.compilation_config = cfg.compilation_config

        # Imported at call time rather than at module import time.
        from vllm.platforms import current_platform
        self.current_platform = current_platform

    def init_device(self) -> None:
        """Set up device state (e.g. model loading or other on-device
        memory allocations). Abstract in this base class.
        """
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Set up the KV cache for the given block counts. Abstract here."""
        raise NotImplementedError

    def get_model(self) -> nn.Module:
        """Return the model held by this worker. Abstract here."""
        raise NotImplementedError

    def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
        """Run `fn` against the model inside this worker and return its
        result.
        """
        model = self.get_model()
        return fn(model)

    def load_model(self) -> None:
        """Place the model on the target device. Abstract here."""
        raise NotImplementedError

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Execute the model on an optional request. Abstract here."""
        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:
        """Run the model-execution loop of a parallel worker.

        The loop keeps calling `execute_model(execute_model_req=None)` and
        stops once it yields `None`, i.e. once a driver worker is executed
        with an empty output.
        See `stop_remote_worker_execution_loop` for more details.
        """
        with self.current_platform.inference_mode():
            # The output is only inspected for the stop signal.
            while self.execute_model(execute_model_req=None) is not None:
                pass
        return None

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Work out how many blocks are available for the GPU KV cache and
        the swappable CPU KV cache. Abstract here.

        Implementations may run profiling or other heuristics to size the
        caches.

        The result is a Tuple[num_gpu_blocks, num_cpu_blocks]: num_gpu_blocks
        counts blocks that are "active" on the device and can be appended to;
        num_cpu_blocks counts "swapped" blocks in CPU memory, which cannot be
        appended to.
        """
        raise NotImplementedError

    def get_cache_block_size_bytes(self) -> int:
        """Size in bytes of one cache block; used in speculative decoding.
        Abstract here.
        """
        raise NotImplementedError

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add the LoRA described by `lora_request`. Abstract here."""
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA identified by `lora_id`. Abstract here."""
        raise NotImplementedError

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA identified by `lora_id`. Abstract here."""
        raise NotImplementedError

    def list_loras(self) -> Set[int]:
        """List LoRAs as a set of integers. Abstract here."""
        raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Vocabulary size, as reported by the model configuration."""
        model_config = self.model_config
        return model_config.get_vocab_size()

    def shutdown(self) -> None:
        """Release resources held by the worker (nothing to do here)."""
        return
class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible
    for lazily initializing the worker and handling the worker's lifecycle.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Then we call `update_environment_variables`, and the
    real initialization happens in `init_worker`.

    Attribute lookups that are not satisfied by the wrapper itself are
    forwarded to the wrapped worker (see `__getattr__`).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rpc_rank: int = 0,
    ) -> None:
        """
        Initialize the worker wrapper with the given vllm_config and rpc_rank.
        Note: rpc_rank is the rank of the worker in the executor. In most cases,
        it is also the rank of the worker in the distributed group. However,
        when multiple executors work together, they can be different.
        e.g. in the case of SPMD-style offline inference with TP=2,
        users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.
        """
        self.rpc_rank = rpc_rank
        self.worker: Optional[WorkerBase] = None
        self.vllm_config: Optional[VllmConfig] = None
        # do not store this `vllm_config`, `init_worker` will set the final
        # one. TODO: investigate if we can remove this field in
        # `WorkerWrapperBase`, `init_cached_hf_modules` should be
        # unnecessary now.
        if vllm_config.model_config is not None:
            # it can be None in tests
            trust_remote_code = vllm_config.model_config.trust_remote_code
            if trust_remote_code:
                # note: lazy import to avoid importing torch before initializing
                from vllm.utils import init_cached_hf_modules
                init_cached_hf_modules()

    def shutdown(self) -> None:
        """Shut down the wrapped worker, if one has been created."""
        if self.worker is not None:
            self.worker.shutdown()

    def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
        """
        Adjust the rpc_rank based on the given mapping.
        It is only used during the initialization of the executor,
        to adjust the rpc_rank of workers after we create all workers.
        Ranks that do not appear in the mapping are left unchanged.
        """
        if self.rpc_rank in rank_mapping:
            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(self, envs_list: List[Dict[str,
                                                                str]]) -> None:
        """Apply the environment variables destined for this worker.

        Args:
            envs_list: One mapping per worker; the entry at index
                `self.rpc_rank` is the one applied.
        """
        envs = envs_list[self.rpc_rank]
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.

        Args:
            all_kwargs: One kwargs dict per worker; the entry at index
                `self.rpc_rank` is used and must contain `vllm_config`.
        """
        kwargs = all_kwargs[self.rpc_rank]
        self.vllm_config = kwargs.get("vllm_config")
        assert self.vllm_config is not None, (
            "vllm_config is required to initialize the worker")
        enable_trace_function_call_for_thread(self.vllm_config)

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        if isinstance(self.vllm_config.parallel_config.worker_cls, str):
            worker_class = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_cls)
        else:
            logger.warning(
                "passing worker_cls as a class object is strongly deprecated,"
                " as the serialization of class objects can be tricky and"
                " error-prone. To be safe, please keep the class in a separate"
                " module and pass the qualified name of the class as a string."
            )
            assert isinstance(self.vllm_config.parallel_config.worker_cls,
                              bytes)
            # SECURITY(review): unpickling executes arbitrary code. This is
            # only safe while `worker_cls` bytes originate from a trusted
            # config; never feed externally supplied data through here.
            worker_class = cloudpickle.loads(
                self.vllm_config.parallel_config.worker_cls)
        if self.vllm_config.parallel_config.worker_extension_cls:
            worker_extension_cls = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_extension_cls)
            extended_calls = []
            if worker_extension_cls not in worker_class.__bases__:
                # check any conflicts between worker and worker_extension_cls
                for attr in dir(worker_extension_cls):
                    if attr.startswith("__"):
                        continue
                    assert not hasattr(worker_class, attr), (
                        f"Worker class {worker_class} already has an attribute"
                        f" {attr}, which conflicts with the worker"
                        f" extension class {worker_extension_cls}.")
                    if callable(getattr(worker_extension_cls, attr)):
                        extended_calls.append(attr)
                # dynamically inherit the worker extension class
                worker_class.__bases__ = worker_class.__bases__ + (
                    worker_extension_cls, )
                logger.info(
                    "Injected %s into %s for extended collective_rpc calls %s",
                    worker_extension_cls, worker_class, extended_calls)
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)
            assert self.worker is not None

    def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
        """Forward this worker's KV cache config to the wrapped worker.

        Args:
            kv_cache_configs: One config per worker; the entry at index
                `self.rpc_rank` is passed on.
        """
        kv_cache_config = kv_cache_configs[self.rpc_rank]
        with set_current_vllm_config(self.vllm_config):
            self.worker.initialize_from_config(
                kv_cache_config)  # type: ignore

    def init_device(self):
        """Initialize the wrapped worker's device state."""
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during device initialization
            self.worker.init_device()  # type: ignore

    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        """Run `method` on this wrapper (or, via `__getattr__`, the worker).

        Any exception is logged and then re-raised unchanged.
        """
        try:
            # method resolution order:
            # if a method is defined in this class, it will be called directly.
            # otherwise, since we define `__getattr__` and redirect attribute
            # query to `self.worker`, the method will be called on the worker.
            return run_method(self, method, args, kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method!r}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # bare `raise` re-raises the active exception with its original
            # traceback intact
            raise

    def __getattr__(self, attr):
        """Forward unknown attribute lookups to the wrapped worker.

        Raises:
            AttributeError: If the attribute does not exist on the worker,
                or if `worker` itself has not been set on this instance.
        """
        # `__getattr__` only runs after normal lookup has failed. If `worker`
        # itself is missing (instance created via `__new__`, e.g. while being
        # copied or unpickled), reading `self.worker` below would re-enter
        # this method forever; fail with a regular AttributeError instead.
        if attr == "worker":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'worker'")
        return getattr(self.worker, attr)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment