Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
94744ba4
Unverified
Commit
94744ba4
authored
Mar 29, 2025
by
wwl2755
Committed by
GitHub
Mar 29, 2025
Browse files
[V1] [Feature] Collective RPC (#15444)
Signed-off-by:
wwl2755
<
wangwenlong2755@gmail.com
>
parent
4965ec42
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
86 additions
and
10 deletions
+86
-10
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+11
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-2
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+11
-1
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+42
-1
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+9
-1
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+8
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
94744ba4
...
...
@@ -150,8 +150,8 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
-
pushd ../examples/offline_inference
-
VLLM_ENABLE_V1_MULTIPROCESSING=0
python3 rlhf.py
-
VLLM_ENABLE_V1_MULTIPROCESSING=0
RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
-
python3 rlhf.py
-
RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
-
popd
-
label
:
Metrics, Tracing Test
# 10min
...
...
@@ -520,7 +520,7 @@ steps:
-
vllm/v1/engine/
commands
:
-
TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-
VLLM_ENABLE_V1_MULTIPROCESSING=0
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_wrapper.py
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
...
...
vllm/engine/llm_engine.py
View file @
94744ba4
...
...
@@ -7,8 +7,8 @@ from collections import deque
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
...
...
@@ -67,6 +67,7 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
PoolingRequestOutput
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
@
dataclass
...
...
@@ -2123,6 +2124,14 @@ class LLMEngine:
return
sampling_params
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
if
envs
.
is_set
(
"VLLM_USE_V1"
)
and
envs
.
VLLM_USE_V1
:
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
V1LLMEngine
...
...
vllm/entrypoints/llm.py
View file @
94744ba4
...
...
@@ -492,8 +492,8 @@ class LLM:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
return
self
.
llm_engine
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
"""
...
...
vllm/v1/engine/core.py
View file @
94744ba4
...
...
@@ -8,7 +8,7 @@ import time
from
concurrent.futures
import
Future
from
inspect
import
isclass
,
signature
from
logging
import
DEBUG
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
msgspec
import
psutil
...
...
@@ -43,6 +43,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S
=
2.5
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
class
EngineCore
:
"""Inner loop of vLLM's Engine."""
...
...
@@ -280,6 +282,14 @@ class EngineCore:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_executor
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
EngineCoreProc
(
EngineCore
):
"""ZMQ-wrapper for running EngineCore in background process."""
...
...
vllm/v1/engine/core_client.py
View file @
94744ba4
...
...
@@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
,
field
from
threading
import
Thread
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
zmq
import
zmq.asyncio
...
...
@@ -33,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture
=
Union
[
asyncio
.
Future
[
Any
],
Future
[
Any
]]
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
class
EngineCoreClient
(
ABC
):
"""
...
...
@@ -117,6 +119,13 @@ class EngineCoreClient(ABC):
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
raise
NotImplementedError
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
raise
NotImplementedError
...
...
@@ -153,6 +162,14 @@ class EngineCoreClient(ABC):
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
async
def
collective_rpc_async
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
raise
NotImplementedError
class
InprocClient
(
EngineCoreClient
):
"""
...
...
@@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient):
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
CoreEngine
:
"""One per data parallel rank."""
...
...
@@ -505,6 +529,14 @@ class SyncMPClient(MPClient):
def
execute_dummy_batch
(
self
)
->
None
:
self
.
call_utility
(
"execute_dummy_batch"
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
call_utility
(
"collective_rpc"
,
method
,
timeout
,
args
,
kwargs
)
class
AsyncMPClient
(
MPClient
):
"""Asyncio-compatible client for multi-proc EngineCore."""
...
...
@@ -636,6 +668,15 @@ class AsyncMPClient(MPClient):
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
return
await
self
.
call_utility_async
(
"pin_lora"
,
lora_id
)
async
def
collective_rpc_async
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
await
self
.
call_utility_async
(
"collective_rpc"
,
method
,
timeout
,
args
,
kwargs
)
class
DPAsyncMPClient
(
AsyncMPClient
):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
...
...
vllm/v1/engine/llm_engine.py
View file @
94744ba4
...
...
@@ -2,7 +2,7 @@
from
collections.abc
import
Mapping
from
copy
import
copy
from
typing
import
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing_extensions
import
TypeVar
...
...
@@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger
=
init_logger
(
__name__
)
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
class
LLMEngine
:
...
...
@@ -282,6 +283,13 @@ class LLMEngine:
"""Prevent an adapter from being evicted."""
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
__del__
(
self
):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
vllm/v1/serial_utils.py
View file @
94744ba4
# SPDX-License-Identifier: Apache-2.0
import
pickle
from
types
import
FunctionType
from
typing
import
Any
,
Optional
import
cloudpickle
import
torch
from
msgspec
import
msgpack
CUSTOM_TYPE_TENSOR
=
1
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_CLOUDPICKLE
=
3
class
MsgpackEncoder
:
...
...
@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
if
isinstance
(
obj
,
FunctionType
):
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
...
...
@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
if
code
==
CUSTOM_TYPE_PICKLE
:
return
pickle
.
loads
(
data
)
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
cloudpickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment