Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
94744ba4
Unverified
Commit
94744ba4
authored
Mar 29, 2025
by
wwl2755
Committed by
GitHub
Mar 29, 2025
Browse files
[V1] [Feature] Collective RPC (#15444)
Signed-off-by:
wwl2755
<
wangwenlong2755@gmail.com
>
parent
4965ec42
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
86 additions
and
10 deletions
+86
-10
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+11
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-2
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+11
-1
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+42
-1
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+9
-1
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+8
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
94744ba4
...
@@ -150,8 +150,8 @@ steps:
...
@@ -150,8 +150,8 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
# when we have multiple distributed example tests
-
pushd ../examples/offline_inference
-
pushd ../examples/offline_inference
-
VLLM_ENABLE_V1_MULTIPROCESSING=0
python3 rlhf.py
-
python3 rlhf.py
-
VLLM_ENABLE_V1_MULTIPROCESSING=0
RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
-
RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
-
popd
-
popd
-
label
:
Metrics, Tracing Test
# 10min
-
label
:
Metrics, Tracing Test
# 10min
...
@@ -520,7 +520,7 @@ steps:
...
@@ -520,7 +520,7 @@ steps:
-
vllm/v1/engine/
-
vllm/v1/engine/
commands
:
commands
:
-
TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-
TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-
VLLM_ENABLE_V1_MULTIPROCESSING=0
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_wrapper.py
-
pytest -v -s ./compile/test_wrapper.py
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
...
...
vllm/engine/llm_engine.py
View file @
94744ba4
...
@@ -7,8 +7,8 @@ from collections import deque
...
@@ -7,8 +7,8 @@ from collections import deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
List
,
Mapping
,
NamedTuple
,
Optional
)
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
...
@@ -67,6 +67,7 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5
...
@@ -67,6 +67,7 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
PoolingRequestOutput
)
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
PoolingRequestOutput
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
@
dataclass
@
dataclass
...
@@ -2123,6 +2124,14 @@ class LLMEngine:
...
@@ -2123,6 +2124,14 @@ class LLMEngine:
return
sampling_params
return
sampling_params
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
if
envs
.
is_set
(
"VLLM_USE_V1"
)
and
envs
.
VLLM_USE_V1
:
if
envs
.
is_set
(
"VLLM_USE_V1"
)
and
envs
.
VLLM_USE_V1
:
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
V1LLMEngine
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
V1LLMEngine
...
...
vllm/entrypoints/llm.py
View file @
94744ba4
...
@@ -492,8 +492,8 @@ class LLM:
...
@@ -492,8 +492,8 @@ class LLM:
It is recommended to use this API to only pass control messages,
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
and set up data-plane communication to pass data.
"""
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
return
self
.
llm_engine
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
"""
"""
...
...
vllm/v1/engine/core.py
View file @
94744ba4
...
@@ -8,7 +8,7 @@ import time
...
@@ -8,7 +8,7 @@ import time
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
inspect
import
isclass
,
signature
from
inspect
import
isclass
,
signature
from
logging
import
DEBUG
from
logging
import
DEBUG
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
msgspec
import
msgspec
import
psutil
import
psutil
...
@@ -43,6 +43,8 @@ logger = init_logger(__name__)
...
@@ -43,6 +43,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S
=
2.5
POLLING_TIMEOUT_S
=
2.5
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
class
EngineCore
:
class
EngineCore
:
"""Inner loop of vLLM's Engine."""
"""Inner loop of vLLM's Engine."""
...
@@ -280,6 +282,14 @@ class EngineCore:
...
@@ -280,6 +282,14 @@ class EngineCore:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_executor
.
pin_lora
(
lora_id
)
return
self
.
model_executor
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
EngineCoreProc
(
EngineCore
):
class
EngineCoreProc
(
EngineCore
):
"""ZMQ-wrapper for running EngineCore in background process."""
"""ZMQ-wrapper for running EngineCore in background process."""
...
...
vllm/v1/engine/core_client.py
View file @
94744ba4
...
@@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence
...
@@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
...
@@ -33,6 +33,8 @@ logger = init_logger(__name__)
...
@@ -33,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture
=
Union
[
asyncio
.
Future
[
Any
],
Future
[
Any
]]
AnyFuture
=
Union
[
asyncio
.
Future
[
Any
],
Future
[
Any
]]
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
class
EngineCoreClient
(
ABC
):
class
EngineCoreClient
(
ABC
):
"""
"""
...
@@ -117,6 +119,13 @@ class EngineCoreClient(ABC):
...
@@ -117,6 +119,13 @@ class EngineCoreClient(ABC):
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
raise
NotImplementedError
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
raise
NotImplementedError
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -153,6 +162,14 @@ class EngineCoreClient(ABC):
...
@@ -153,6 +162,14 @@ class EngineCoreClient(ABC):
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
raise
NotImplementedError
async
def
collective_rpc_async
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
raise
NotImplementedError
class
InprocClient
(
EngineCoreClient
):
class
InprocClient
(
EngineCoreClient
):
"""
"""
...
@@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient):
...
@@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient):
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
engine_core
.
pin_lora
(
lora_id
)
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
CoreEngine
:
class
CoreEngine
:
"""One per data parallel rank."""
"""One per data parallel rank."""
...
@@ -505,6 +529,14 @@ class SyncMPClient(MPClient):
...
@@ -505,6 +529,14 @@ class SyncMPClient(MPClient):
def
execute_dummy_batch
(
self
)
->
None
:
def
execute_dummy_batch
(
self
)
->
None
:
self
.
call_utility
(
"execute_dummy_batch"
)
self
.
call_utility
(
"execute_dummy_batch"
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
call_utility
(
"collective_rpc"
,
method
,
timeout
,
args
,
kwargs
)
class
AsyncMPClient
(
MPClient
):
class
AsyncMPClient
(
MPClient
):
"""Asyncio-compatible client for multi-proc EngineCore."""
"""Asyncio-compatible client for multi-proc EngineCore."""
...
@@ -636,6 +668,15 @@ class AsyncMPClient(MPClient):
...
@@ -636,6 +668,15 @@ class AsyncMPClient(MPClient):
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
return
await
self
.
call_utility_async
(
"pin_lora"
,
lora_id
)
return
await
self
.
call_utility_async
(
"pin_lora"
,
lora_id
)
async
def
collective_rpc_async
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
await
self
.
call_utility_async
(
"collective_rpc"
,
method
,
timeout
,
args
,
kwargs
)
class
DPAsyncMPClient
(
AsyncMPClient
):
class
DPAsyncMPClient
(
AsyncMPClient
):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
...
...
vllm/v1/engine/llm_engine.py
View file @
94744ba4
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
copy
import
copy
from
copy
import
copy
from
typing
import
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
...
@@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
...
@@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
class
LLMEngine
:
class
LLMEngine
:
...
@@ -282,6 +283,13 @@ class LLMEngine:
...
@@ -282,6 +283,13 @@ class LLMEngine:
"""Prevent an adapter from being evicted."""
"""Prevent an adapter from being evicted."""
return
self
.
engine_core
.
pin_lora
(
lora_id
)
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
__del__
(
self
):
def
__del__
(
self
):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
stateless_destroy_torch_distributed_process_group
(
dp_group
)
vllm/v1/serial_utils.py
View file @
94744ba4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pickle
import
pickle
from
types
import
FunctionType
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
cloudpickle
import
torch
import
torch
from
msgspec
import
msgpack
from
msgspec
import
msgpack
CUSTOM_TYPE_TENSOR
=
1
CUSTOM_TYPE_TENSOR
=
1
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_CLOUDPICKLE
=
3
class
MsgpackEncoder
:
class
MsgpackEncoder
:
...
@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
...
@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
if
isinstance
(
obj
,
FunctionType
):
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
...
@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
...
@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
if
code
==
CUSTOM_TYPE_PICKLE
:
if
code
==
CUSTOM_TYPE_PICKLE
:
return
pickle
.
loads
(
data
)
return
pickle
.
loads
(
data
)
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
cloudpickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment