Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87a0c076
Unverified
Commit
87a0c076
authored
Jan 17, 2025
by
youkaichao
Committed by
GitHub
Jan 17, 2025
Browse files
[core] allow callable in collective_rpc (#12151)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
d4e61945
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
147 additions
and
50 deletions
+147
-50
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-1
tests/engine/test_custom_executor.py
tests/engine/test_custom_executor.py
+2
-2
tests/entrypoints/llm/test_collective_rpc.py
tests/entrypoints/llm/test_collective_rpc.py
+36
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+14
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+9
-5
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+5
-4
vllm/executor/mp_distributed_executor.py
vllm/executor/mp_distributed_executor.py
+14
-7
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+6
-6
vllm/executor/ray_distributed_executor.py
vllm/executor/ray_distributed_executor.py
+10
-4
vllm/executor/uniproc_executor.py
vllm/executor/uniproc_executor.py
+5
-9
vllm/utils.py
vllm/utils.py
+23
-0
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+15
-4
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+5
-5
No files found.
.buildkite/test-pipeline.yaml
View file @
87a0c076
...
@@ -107,7 +107,7 @@ steps:
...
@@ -107,7 +107,7 @@ steps:
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
commands
:
commands
:
-
pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
-
pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
--ignore=entrypoints/llm/test_collective_rpc.py
-
pytest -v -s entrypoints/llm/test_lazy_outlines.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_lazy_outlines.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_generate.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_generate.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_generate_multiple_loras.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_generate_multiple_loras.py
# it needs a clean process
...
@@ -466,7 +466,9 @@ steps:
...
@@ -466,7 +466,9 @@ steps:
-
vllm/worker/worker_base.py
-
vllm/worker/worker_base.py
-
vllm/worker/worker.py
-
vllm/worker/worker.py
-
vllm/worker/model_runner.py
-
vllm/worker/model_runner.py
-
entrypoints/llm/test_collective_rpc.py
commands
:
commands
:
-
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
-
torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_wrapper.py
-
pytest -v -s ./compile/test_wrapper.py
...
...
tests/engine/test_custom_executor.py
View file @
87a0c076
import
asyncio
import
asyncio
import
os
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
pytest
import
pytest
...
@@ -18,7 +18,7 @@ class Mock:
...
@@ -18,7 +18,7 @@ class Mock:
class
CustomUniExecutor
(
UniProcExecutor
):
class
CustomUniExecutor
(
UniProcExecutor
):
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
str
,
method
:
Union
[
str
,
Callable
]
,
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
...
...
tests/entrypoints/llm/test_collective_rpc.py
0 → 100644
View file @
87a0c076
import
pytest
from
vllm
import
LLM
from
...utils
import
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"mp"
,
"ray"
])
@
fork_new_process_for_each_test
def
test_collective_rpc
(
tp_size
,
backend
):
if
tp_size
==
1
and
backend
==
"ray"
:
pytest
.
skip
(
"Skip duplicate test case"
)
if
tp_size
==
1
:
backend
=
None
# intentionally define the method and class in the test function,
# to test if they can be serialized and sent to the workers
def
echo_rank
(
self
):
return
self
.
rank
from
vllm.worker.worker
import
Worker
class
MyWorker
(
Worker
):
def
echo_rank
(
self
):
return
self
.
rank
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
enforce_eager
=
True
,
load_format
=
"dummy"
,
tensor_parallel_size
=
tp_size
,
distributed_executor_backend
=
backend
,
worker_cls
=
MyWorker
)
for
method
in
[
"echo_rank"
,
echo_rank
]:
assert
llm
.
collective_rpc
(
method
)
==
list
(
range
(
tp_size
))
vllm/engine/llm_engine.py
View file @
87a0c076
...
@@ -5,10 +5,10 @@ from collections import deque
...
@@ -5,10 +5,10 @@ from collections import deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
List
,
Mapping
,
NamedTuple
,
Optional
)
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
from
typing
import
Set
,
Tuple
,
Type
,
Union
,
cast
,
overload
import
torch
import
torch
from
typing_extensions
import
TypeVar
,
deprecated
from
typing_extensions
import
TypeVar
,
deprecated
...
@@ -1816,6 +1816,17 @@ class LLMEngine:
...
@@ -1816,6 +1816,17 @@ class LLMEngine:
def
stop_profile
(
self
)
->
None
:
def
stop_profile
(
self
)
->
None
:
self
.
model_executor
.
stop_profile
()
self
.
model_executor
.
stop_profile
()
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
"""
See LLM.collective_rpc for more details.
"""
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
if
self
.
tokenizer
:
if
self
.
tokenizer
:
self
.
tokenizer
.
check_health
()
self
.
tokenizer
.
check_health
()
...
...
vllm/entrypoints/llm.py
View file @
87a0c076
import
itertools
import
itertools
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
from
typing
import
(
Any
,
Callable
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Union
,
cast
,
overload
)
Tuple
,
Type
,
Union
,
cast
,
overload
)
import
cloudpickle
import
cloudpickle
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -464,7 +464,7 @@ class LLM:
...
@@ -464,7 +464,7 @@ class LLM:
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
str
,
method
:
Union
[
str
,
Callable
]
,
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
...
@@ -476,9 +476,13 @@ class LLM:
...
@@ -476,9 +476,13 @@ class LLM:
Then, users can call the new methods through this API.
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
and set up data-plane communication to pass data.
The method can also be a callable, which will be serialized
and sent to all workers to execute.
If the method is a callable, it should accept `self` as its first
positional argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
"""
"""
return
self
.
llm_engine
.
model_executor
.
collective_rpc
(
return
self
.
llm_engine
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
method
,
timeout
,
args
,
kwargs
)
def
beam_search
(
def
beam_search
(
self
,
self
,
...
...
vllm/executor/executor_base.py
View file @
87a0c076
import
asyncio
import
asyncio
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Awaitable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -47,7 +48,7 @@ class ExecutorBase(ABC):
...
@@ -47,7 +48,7 @@ class ExecutorBase(ABC):
@
abstractmethod
@
abstractmethod
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
str
,
method
:
Union
[
str
,
Callable
]
,
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
...
@@ -260,7 +261,7 @@ class DistributedExecutorBase(ExecutorBase):
...
@@ -260,7 +261,7 @@ class DistributedExecutorBase(ExecutorBase):
raise
NotImplementedError
raise
NotImplementedError
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
str
,
method
:
Union
[
str
,
Callable
]
,
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
...
@@ -269,7 +270,7 @@ class DistributedExecutorBase(ExecutorBase):
...
@@ -269,7 +270,7 @@ class DistributedExecutorBase(ExecutorBase):
@
abstractmethod
@
abstractmethod
def
_run_workers
(
def
_run_workers
(
self
,
self
,
method
:
str
,
method
:
Union
[
str
,
Callable
]
,
*
args
,
*
args
,
async_run_tensor_parallel_workers_only
:
bool
=
False
,
async_run_tensor_parallel_workers_only
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
...
...
vllm/executor/mp_distributed_executor.py
View file @
87a0c076
import
asyncio
import
asyncio
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
import
cloudpickle
from
vllm.executor.executor_base
import
DistributedExecutorBase
from
vllm.executor.executor_base
import
DistributedExecutorBase
from
vllm.executor.multiproc_worker_utils
import
(
from
vllm.executor.multiproc_worker_utils
import
(
...
@@ -9,7 +11,7 @@ from vllm.logger import init_logger
...
@@ -9,7 +11,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
_run_task_with_lock
,
get_distributed_init_method
,
from
vllm.utils
import
(
_run_task_with_lock
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
get_ip
,
get_open_port
,
make_async
,
run_method
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -107,7 +109,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
...
@@ -107,7 +109,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
def
_run_workers
(
def
_run_workers
(
self
,
self
,
method
:
str
,
method
:
Union
[
str
,
Callable
]
,
*
args
,
*
args
,
async_run_tensor_parallel_workers_only
:
bool
=
False
,
async_run_tensor_parallel_workers_only
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
...
@@ -121,6 +123,11 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
...
@@ -121,6 +123,11 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
It will also be run asynchronously and return a list of futures
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
rather than blocking on the results.
"""
"""
if
isinstance
(
method
,
str
):
sent_method
=
method
else
:
sent_method
=
cloudpickle
.
dumps
(
method
)
del
method
if
max_concurrent_workers
:
if
max_concurrent_workers
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
@@ -129,18 +136,18 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
...
@@ -129,18 +136,18 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
if
async_run_tensor_parallel_workers_only
:
if
async_run_tensor_parallel_workers_only
:
# Run only non-driver workers and just return futures.
# Run only non-driver workers and just return futures.
return
[
return
[
worker
.
execute_method
(
method
,
*
args
,
**
kwargs
)
worker
.
execute_method
(
sent_
method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
non_driver_workers
for
worker
in
self
.
non_driver_workers
]
]
# Start all remote workers first.
# Start all remote workers first.
worker_outputs
=
[
worker_outputs
=
[
worker
.
execute_method
(
method
,
*
args
,
**
kwargs
)
worker
.
execute_method
(
sent_
method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
workers
for
worker
in
self
.
workers
]
]
driver_worker_
method
=
getattr
(
self
.
driver_worker
,
method
)
driver_worker_
output
=
run_method
(
self
.
driver_worker
,
sent_
method
,
driver_worker_output
=
driver_worker_method
(
*
args
,
**
kwargs
)
args
,
kwargs
)
# Get the results of the workers.
# Get the results of the workers.
return
[
driver_worker_output
return
[
driver_worker_output
...
...
vllm/executor/multiproc_worker_utils.py
View file @
87a0c076
...
@@ -15,7 +15,7 @@ import torch
...
@@ -15,7 +15,7 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.triton_utils.importing
import
HAS_TRITON
from
vllm.triton_utils.importing
import
HAS_TRITON
from
vllm.utils
import
_check_multiproc_method
,
get_mp_context
from
vllm.utils
import
_check_multiproc_method
,
get_mp_context
,
run_method
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.triton_utils
import
maybe_set_triton_cache_manager
from
vllm.triton_utils
import
maybe_set_triton_cache_manager
...
@@ -169,7 +169,7 @@ class ProcessWorkerWrapper:
...
@@ -169,7 +169,7 @@ class ProcessWorkerWrapper:
self
.
process
.
start
()
self
.
process
.
start
()
def
_enqueue_task
(
self
,
future
:
Union
[
ResultFuture
,
asyncio
.
Future
],
def
_enqueue_task
(
self
,
future
:
Union
[
ResultFuture
,
asyncio
.
Future
],
method
:
str
,
args
,
kwargs
):
method
:
Union
[
str
,
bytes
]
,
args
,
kwargs
):
task_id
=
uuid
.
uuid4
()
task_id
=
uuid
.
uuid4
()
self
.
tasks
[
task_id
]
=
future
self
.
tasks
[
task_id
]
=
future
try
:
try
:
...
@@ -180,12 +180,13 @@ class ProcessWorkerWrapper:
...
@@ -180,12 +180,13 @@ class ProcessWorkerWrapper:
del
self
.
tasks
[
task_id
]
del
self
.
tasks
[
task_id
]
raise
ChildProcessError
(
"worker died"
)
from
e
raise
ChildProcessError
(
"worker died"
)
from
e
def
execute_method
(
self
,
method
:
str
,
*
args
,
**
kwargs
):
def
execute_method
(
self
,
method
:
Union
[
str
,
bytes
]
,
*
args
,
**
kwargs
):
future
:
ResultFuture
=
ResultFuture
()
future
:
ResultFuture
=
ResultFuture
()
self
.
_enqueue_task
(
future
,
method
,
args
,
kwargs
)
self
.
_enqueue_task
(
future
,
method
,
args
,
kwargs
)
return
future
return
future
async
def
execute_method_async
(
self
,
method
:
str
,
*
args
,
**
kwargs
):
async
def
execute_method_async
(
self
,
method
:
Union
[
str
,
bytes
],
*
args
,
**
kwargs
):
future
=
asyncio
.
get_running_loop
().
create_future
()
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
_enqueue_task
(
future
,
method
,
args
,
kwargs
)
self
.
_enqueue_task
(
future
,
method
,
args
,
kwargs
)
return
await
future
return
await
future
...
@@ -230,8 +231,7 @@ def _run_worker_process(
...
@@ -230,8 +231,7 @@ def _run_worker_process(
exception
=
None
exception
=
None
task_id
,
method
,
args
,
kwargs
=
items
task_id
,
method
,
args
,
kwargs
=
items
try
:
try
:
executor
=
getattr
(
worker
,
method
)
output
=
run_method
(
worker
,
method
,
args
,
kwargs
)
output
=
executor
(
*
args
,
**
kwargs
)
except
SystemExit
:
except
SystemExit
:
raise
raise
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
...
...
vllm/executor/ray_distributed_executor.py
View file @
87a0c076
...
@@ -2,8 +2,9 @@ import asyncio
...
@@ -2,8 +2,9 @@ import asyncio
import
os
import
os
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
cloudpickle
import
msgspec
import
msgspec
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -410,7 +411,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
...
@@ -410,7 +411,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
def
_run_workers
(
def
_run_workers
(
self
,
self
,
method
:
str
,
method
:
Union
[
str
,
Callable
]
,
*
args
,
*
args
,
async_run_tensor_parallel_workers_only
:
bool
=
False
,
async_run_tensor_parallel_workers_only
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
...
@@ -426,6 +427,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
...
@@ -426,6 +427,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
rather than blocking on the results.
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
- args/kwargs: All workers share the same args/kwargs
"""
"""
if
isinstance
(
method
,
str
):
sent_method
=
method
else
:
sent_method
=
cloudpickle
.
dumps
(
method
)
del
method
if
self
.
use_ray_spmd_worker
:
if
self
.
use_ray_spmd_worker
:
assert
not
async_run_tensor_parallel_workers_only
,
(
assert
not
async_run_tensor_parallel_workers_only
,
(
"async_run_tensor_parallel_workers_only is not supported for "
"async_run_tensor_parallel_workers_only is not supported for "
...
@@ -440,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
...
@@ -440,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
if
async_run_tensor_parallel_workers_only
:
if
async_run_tensor_parallel_workers_only
:
ray_workers
=
self
.
non_driver_workers
ray_workers
=
self
.
non_driver_workers
ray_worker_outputs
=
[
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
args
,
**
kwargs
)
worker
.
execute_method
.
remote
(
sent_
method
,
*
args
,
**
kwargs
)
for
worker
in
ray_workers
for
worker
in
ray_workers
]
]
...
@@ -455,7 +461,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
...
@@ -455,7 +461,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
if
not
self
.
use_ray_spmd_worker
:
if
not
self
.
use_ray_spmd_worker
:
# Start the driver worker after all the ray workers.
# Start the driver worker after all the ray workers.
driver_worker_output
=
[
driver_worker_output
=
[
self
.
driver_worker
.
execute_method
(
method
,
*
args
,
**
kwargs
)
self
.
driver_worker
.
execute_method
(
sent_
method
,
*
args
,
**
kwargs
)
]
]
# Get the results of the ray workers.
# Get the results of the ray workers.
...
...
vllm/executor/uniproc_executor.py
View file @
87a0c076
import
os
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -7,7 +7,8 @@ import torch.distributed as dist
...
@@ -7,7 +7,8 @@ import torch.distributed as dist
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
run_method
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -39,18 +40,13 @@ class UniProcExecutor(ExecutorBase):
...
@@ -39,18 +40,13 @@ class UniProcExecutor(ExecutorBase):
self
.
collective_rpc
(
"load_model"
)
self
.
collective_rpc
(
"load_model"
)
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
str
,
method
:
Union
[
str
,
Callable
]
,
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
if
kwargs
is
None
:
if
kwargs
is
None
:
kwargs
=
{}
kwargs
=
{}
try
:
answer
=
run_method
(
self
.
driver_worker
,
method
,
args
,
kwargs
)
func
=
getattr
(
self
.
driver_worker
,
method
)
except
AttributeError
:
raise
NotImplementedError
(
f
"Method
{
method
}
is not implemented."
)
\
from
None
answer
=
func
(
*
args
,
**
kwargs
)
return
[
answer
]
return
[
answer
]
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
...
...
vllm/utils.py
View file @
87a0c076
...
@@ -36,6 +36,7 @@ from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
...
@@ -36,6 +36,7 @@ from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
overload
)
overload
)
from
uuid
import
uuid4
from
uuid
import
uuid4
import
cloudpickle
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
import
psutil
import
psutil
...
@@ -2166,3 +2167,25 @@ def bind_kv_cache(
...
@@ -2166,3 +2167,25 @@ def bind_kv_cache(
assert
len
(
forward_ctx
.
kv_cache
)
==
len
(
kv_cache
)
assert
len
(
forward_ctx
.
kv_cache
)
==
len
(
kv_cache
)
for
ve
,
ve_kv_cache
in
enumerate
(
kv_cache
):
for
ve
,
ve_kv_cache
in
enumerate
(
kv_cache
):
forward_ctx
.
kv_cache
[
ve
]
=
ve_kv_cache
[
kv_cache_idx
]
forward_ctx
.
kv_cache
[
ve
]
=
ve_kv_cache
[
kv_cache_idx
]
def
run_method
(
obj
:
Any
,
method
:
Union
[
str
,
bytes
,
Callable
],
args
:
Tuple
[
Any
],
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
"""
Run a method of an object with the given arguments and keyword arguments.
If the method is a string, it will be converted to a method using getattr.
If the method is serialized bytes, it will be deserialized using
cloudpickle.
If the method is a callable, it will be called directly.
"""
if
isinstance
(
method
,
bytes
):
func
=
partial
(
cloudpickle
.
loads
(
method
),
obj
)
elif
isinstance
(
method
,
str
):
try
:
func
=
getattr
(
obj
,
method
)
except
AttributeError
:
raise
NotImplementedError
(
f
"Method
{
method
!
r
}
is not"
" implemented."
)
from
None
else
:
func
=
partial
(
method
,
obj
)
# type: ignore
return
func
(
*
args
,
**
kwargs
)
vllm/v1/executor/multiproc_executor.py
View file @
87a0c076
...
@@ -6,9 +6,11 @@ import time
...
@@ -6,9 +6,11 @@ import time
import
weakref
import
weakref
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
functools
import
partial
from
multiprocessing.process
import
BaseProcess
from
multiprocessing.process
import
BaseProcess
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
cloudpickle
import
psutil
import
psutil
import
zmq
import
zmq
...
@@ -120,7 +122,7 @@ class MultiprocExecutor(Executor):
...
@@ -120,7 +122,7 @@ class MultiprocExecutor(Executor):
return
kv_cache_specs
[
0
]
return
kv_cache_specs
[
0
]
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
str
,
method
:
Union
[
str
,
Callable
]
,
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
...
@@ -141,7 +143,12 @@ class MultiprocExecutor(Executor):
...
@@ -141,7 +143,12 @@ class MultiprocExecutor(Executor):
kwargs
=
kwargs
or
{}
kwargs
=
kwargs
or
{}
try
:
try
:
self
.
rpc_broadcast_mq
.
enqueue
((
method
,
args
,
kwargs
))
if
isinstance
(
method
,
str
):
send_method
=
method
else
:
send_method
=
cloudpickle
.
dumps
(
method
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
self
.
rpc_broadcast_mq
.
enqueue
((
send_method
,
args
,
kwargs
))
responses
=
[
None
]
*
self
.
world_size
responses
=
[
None
]
*
self
.
world_size
for
w
in
self
.
workers
:
for
w
in
self
.
workers
:
...
@@ -408,7 +415,11 @@ class WorkerProc:
...
@@ -408,7 +415,11 @@ class WorkerProc:
method
,
args
,
kwargs
=
self
.
rpc_broadcast_mq
.
dequeue
()
method
,
args
,
kwargs
=
self
.
rpc_broadcast_mq
.
dequeue
()
try
:
try
:
output
=
getattr
(
self
.
worker
,
method
)(
*
args
,
**
kwargs
)
if
isinstance
(
method
,
str
):
func
=
getattr
(
self
.
worker
,
method
)
elif
isinstance
(
method
,
bytes
):
func
=
partial
(
cloudpickle
.
loads
(
method
),
self
.
worker
)
output
=
func
(
*
args
,
**
kwargs
)
except
Exception
as
e
:
except
Exception
as
e
:
self
.
worker_response_mq
.
enqueue
(
self
.
worker_response_mq
.
enqueue
(
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
e
))
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
e
))
...
...
vllm/worker/worker_base.py
View file @
87a0c076
...
@@ -14,7 +14,8 @@ from vllm.lora.request import LoRARequest
...
@@ -14,7 +14,8 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
resolve_obj_by_qualname
,
update_environment_variables
)
resolve_obj_by_qualname
,
run_method
,
update_environment_variables
)
from
vllm.worker.model_runner_base
import
(
BroadcastableModelInput
,
from
vllm.worker.model_runner_base
import
(
BroadcastableModelInput
,
ModelRunnerBase
,
ModelRunnerBase
,
ModelRunnerInputBase
)
ModelRunnerInputBase
)
...
@@ -539,17 +540,16 @@ class WorkerWrapperBase:
...
@@ -539,17 +540,16 @@ class WorkerWrapperBase:
self
.
worker
=
worker_class
(
**
kwargs
)
self
.
worker
=
worker_class
(
**
kwargs
)
assert
self
.
worker
is
not
None
assert
self
.
worker
is
not
None
def
execute_method
(
self
,
method
:
str
,
*
args
,
**
kwargs
):
def
execute_method
(
self
,
method
:
Union
[
str
,
bytes
]
,
*
args
,
**
kwargs
):
try
:
try
:
target
=
self
if
self
.
worker
is
None
else
self
.
worker
target
=
self
if
self
.
worker
is
None
else
self
.
worker
executor
=
getattr
(
target
,
method
)
return
run_method
(
target
,
method
,
args
,
kwargs
)
return
executor
(
*
args
,
**
kwargs
)
except
Exception
as
e
:
except
Exception
as
e
:
# if the driver worker also executes methods,
# if the driver worker also executes methods,
# exceptions in the remaining workers may cause deadlock in rpc like ray
# exceptions in the remaining workers may cause deadlock in rpc like ray
# see https://github.com/vllm-project/vllm/issues/3455
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
# print the error and inform the user to solve the error
msg
=
(
f
"Error executing method
{
method
}
. "
msg
=
(
f
"Error executing method
{
method
!
r
}
. "
"This might cause deadlock in distributed execution."
)
"This might cause deadlock in distributed execution."
)
logger
.
exception
(
msg
)
logger
.
exception
(
msg
)
raise
e
raise
e
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment