Unverified commit 94744ba4, authored by wwl2755, committed by GitHub
Browse files

[V1] [Feature] Collective RPC (#15444)


Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
parent 4965ec42
...@@ -150,8 +150,8 @@ steps: ...@@ -150,8 +150,8 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests # TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests # when we have multiple distributed example tests
- pushd ../examples/offline_inference - pushd ../examples/offline_inference
- VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 rlhf.py - python3 rlhf.py
- VLLM_ENABLE_V1_MULTIPROCESSING=0 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd - popd
- label: Metrics, Tracing Test # 10min - label: Metrics, Tracing Test # 10min
...@@ -520,7 +520,7 @@ steps: ...@@ -520,7 +520,7 @@ steps:
- vllm/v1/engine/ - vllm/v1/engine/
commands: commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py - pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
......
...@@ -7,8 +7,8 @@ from collections import deque ...@@ -7,8 +7,8 @@ from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
List, Mapping, NamedTuple, Optional) Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload from typing import Set, Type, Union, cast, overload
...@@ -67,6 +67,7 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5 ...@@ -67,6 +67,7 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
_R = TypeVar("_R", default=Any)
@dataclass @dataclass
...@@ -2123,6 +2124,14 @@ class LLMEngine: ...@@ -2123,6 +2124,14 @@ class LLMEngine:
return sampling_params return sampling_params
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Delegate a collective RPC to ``self.model_executor``.

    Args:
        method: Either a name (``str``) or a callable; passed through
            untouched.
        timeout: Optional timeout value, passed through untouched.
        args: Positional arguments for ``method``, passed as one tuple.
        kwargs: Keyword arguments for ``method``, passed as one dict or
            ``None``.

    Returns:
        Whatever ``self.model_executor.collective_rpc`` returns
        (annotated as a list of ``_R``).
    """
    # All four values are forwarded positionally, in the same order.
    executor = self.model_executor
    results = executor.collective_rpc(method, timeout, args, kwargs)
    return results
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
......
...@@ -492,8 +492,8 @@ class LLM: ...@@ -492,8 +492,8 @@ class LLM:
It is recommended to use this API to only pass control messages, It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data. and set up data-plane communication to pass data.
""" """
executor = self.llm_engine.model_executor
return executor.collective_rpc(method, timeout, args, kwargs) return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
""" """
......
...@@ -8,7 +8,7 @@ import time ...@@ -8,7 +8,7 @@ import time
from concurrent.futures import Future from concurrent.futures import Future
from inspect import isclass, signature from inspect import isclass, signature
from logging import DEBUG from logging import DEBUG
from typing import Any, Optional from typing import Any, Callable, Optional, TypeVar, Union
import msgspec import msgspec
import psutil import psutil
...@@ -43,6 +43,8 @@ logger = init_logger(__name__) ...@@ -43,6 +43,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5 POLLING_TIMEOUT_S = 2.5
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCore: class EngineCore:
"""Inner loop of vLLM's Engine.""" """Inner loop of vLLM's Engine."""
...@@ -280,6 +282,14 @@ class EngineCore: ...@@ -280,6 +282,14 @@ class EngineCore:
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id) return self.model_executor.pin_lora(lora_id)
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Hand a collective RPC request to this core's model executor.

    The arguments are not inspected here; they are forwarded positionally
    to ``self.model_executor.collective_rpc`` and its result is returned
    unchanged.

    Args:
        method: Method name or callable to forward.
        timeout: Optional timeout to forward.
        args: Tuple of positional arguments to forward.
        kwargs: Dict of keyword arguments to forward, or ``None``.

    Returns:
        The value produced by the executor's ``collective_rpc``.
    """
    forward = self.model_executor.collective_rpc
    return forward(method, timeout, args, kwargs)
class EngineCoreProc(EngineCore): class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process.""" """ZMQ-wrapper for running EngineCore in background process."""
......
...@@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence ...@@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass, field from dataclasses import dataclass, field
from threading import Thread from threading import Thread
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, TypeVar, Union
import zmq import zmq
import zmq.asyncio import zmq.asyncio
...@@ -33,6 +33,8 @@ logger = init_logger(__name__) ...@@ -33,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture = Union[asyncio.Future[Any], Future[Any]] AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCoreClient(ABC): class EngineCoreClient(ABC):
""" """
...@@ -117,6 +119,13 @@ class EngineCoreClient(ABC): ...@@ -117,6 +119,13 @@ class EngineCoreClient(ABC):
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError raise NotImplementedError
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Synchronous collective RPC entry point; not implemented here.

    This base version only fixes the signature. Calling it always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
async def get_output_async(self) -> EngineCoreOutputs: async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError raise NotImplementedError
...@@ -153,6 +162,14 @@ class EngineCoreClient(ABC): ...@@ -153,6 +162,14 @@ class EngineCoreClient(ABC):
async def pin_lora_async(self, lora_id: int) -> bool: async def pin_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError raise NotImplementedError
async def collective_rpc_async(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Coroutine form of the collective RPC entry point; not implemented here.

    This base version only fixes the signature. Awaiting it always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
class InprocClient(EngineCoreClient): class InprocClient(EngineCoreClient):
""" """
...@@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient): ...@@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient):
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.engine_core.pin_lora(lora_id) return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Call ``collective_rpc`` directly on the held ``engine_core``.

    Args:
        method: Method name or callable, forwarded as is.
        timeout: Optional timeout, forwarded as is.
        args: Tuple of positional arguments, forwarded as is.
        kwargs: Dict of keyword arguments or ``None``, forwarded as is.

    Returns:
        The result of ``self.engine_core.collective_rpc``.
    """
    core = self.engine_core
    results = core.collective_rpc(method, timeout, args, kwargs)
    return results
class CoreEngine: class CoreEngine:
"""One per data parallel rank.""" """One per data parallel rank."""
...@@ -505,6 +529,14 @@ class SyncMPClient(MPClient): ...@@ -505,6 +529,14 @@ class SyncMPClient(MPClient):
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self.call_utility("execute_dummy_batch") self.call_utility("execute_dummy_batch")
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Issue a ``"collective_rpc"`` utility call and return its result.

    The request is routed through ``self.call_utility`` under the utility
    name ``"collective_rpc"``, followed by the four arguments in order.

    Args:
        method: Method name or callable to send.
        timeout: Optional timeout to send.
        args: Tuple of positional arguments to send.
        kwargs: Dict of keyword arguments to send, or ``None``.

    Returns:
        Whatever ``self.call_utility`` returns for this request.
    """
    rpc_payload = (method, timeout, args, kwargs)
    return self.call_utility("collective_rpc", *rpc_payload)
class AsyncMPClient(MPClient): class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore.""" """Asyncio-compatible client for multi-proc EngineCore."""
...@@ -636,6 +668,15 @@ class AsyncMPClient(MPClient): ...@@ -636,6 +668,15 @@ class AsyncMPClient(MPClient):
async def pin_lora_async(self, lora_id: int) -> bool: async def pin_lora_async(self, lora_id: int) -> bool:
return await self.call_utility_async("pin_lora", lora_id) return await self.call_utility_async("pin_lora", lora_id)
async def collective_rpc_async(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Await a ``"collective_rpc"`` utility call and return its result.

    The request is routed through ``self.call_utility_async`` under the
    utility name ``"collective_rpc"``, followed by the four arguments in
    order.

    Args:
        method: Method name or callable to send.
        timeout: Optional timeout to send.
        args: Tuple of positional arguments to send.
        kwargs: Dict of keyword arguments to send, or ``None``.

    Returns:
        Whatever the awaited ``self.call_utility_async`` call yields.
    """
    results = await self.call_utility_async("collective_rpc", method,
                                            timeout, args, kwargs)
    return results
class DPAsyncMPClient(AsyncMPClient): class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel) """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from copy import copy from copy import copy
from typing import Optional, Union from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar from typing_extensions import TypeVar
...@@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor ...@@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__) logger = init_logger(__name__)
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_R = TypeVar("_R", default=Any)
class LLMEngine: class LLMEngine:
...@@ -282,6 +283,13 @@ class LLMEngine: ...@@ -282,6 +283,13 @@ class LLMEngine:
"""Prevent an adapter from being evicted.""" """Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id) return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Forward a collective RPC request to ``self.engine_core``.

    Args:
        method: Method name or callable, passed through unchanged.
        timeout: Optional timeout, passed through unchanged.
        args: Tuple of positional arguments, passed through unchanged.
        kwargs: Dict of keyword arguments or ``None``, passed through
            unchanged.

    Returns:
        The result of ``self.engine_core.collective_rpc``.
    """
    client = self.engine_core
    return client.collective_rpc(method, timeout, args, kwargs)
def __del__(self): def __del__(self):
if dp_group := getattr(self, "dp_group", None): if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group) stateless_destroy_torch_distributed_process_group(dp_group)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pickle import pickle
from types import FunctionType
from typing import Any, Optional from typing import Any, Optional
import cloudpickle
import torch import torch
from msgspec import msgpack from msgspec import msgpack
CUSTOM_TYPE_TENSOR = 1 CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2 CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_CLOUDPICKLE = 3
class MsgpackEncoder: class MsgpackEncoder:
...@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any: ...@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy())) return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
if isinstance(obj, FunctionType):
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj)) return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
...@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any: ...@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
return torch.from_numpy(pickle.loads(data)) return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE: if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data) return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)
raise NotImplementedError(f"Extension type code {code} is not supported") raise NotImplementedError(f"Extension type code {code} is not supported")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment