Unverified Commit ed812a73 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[ Frontend ] Multiprocessing for OpenAI Server with `zeromq` (#6883)


Signed-off-by: default avatarJoe Runde <Joseph.Runde@ibm.com>
Co-authored-by: default avatarJoe Runde <Joseph.Runde@ibm.com>
Co-authored-by: default avatarJoe Runde <joe@joerun.de>
Co-authored-by: default avatarNick Hill <nickhill@us.ibm.com>
Co-authored-by: default avatarSimon Mo <simon.mo@hey.com>
parent 70898934
This diff is collapsed.
......@@ -7,7 +7,8 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
......@@ -928,6 +929,14 @@ class AsyncLLMEngine:
else:
return self.engine.get_model_config()
async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_parallel_config.remote( # type: ignore
)
else:
return self.engine.get_parallel_config()
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
......@@ -936,6 +945,22 @@ class AsyncLLMEngine:
else:
return self.engine.get_decoding_config()
async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_scheduler_config.remote( # type: ignore
)
else:
return self.engine.get_scheduler_config()
async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_lora_config.remote( # type: ignore
)
else:
return self.engine.get_lora_config()
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
......
......@@ -38,9 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (AnyTokenizer,
BaseTokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
......@@ -485,19 +484,12 @@ class LLMEngine:
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _init_tokenizer(self) -> BaseTokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
enable_lora=bool(self.lora_config))
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
......@@ -759,10 +751,22 @@ class LLMEngine:
"""Gets the model configuration."""
return self.model_config
def get_parallel_config(self) -> ParallelConfig:
"""Gets the parallel configuration."""
return self.parallel_config
def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config
def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration."""
return self.scheduler_config
def get_lora_config(self) -> LoRAConfig:
"""Gets the LoRA configuration."""
return self.lora_config
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return sum(scheduler.get_num_unfinished_seq_groups()
......
from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
runtime_checkable)
from transformers import PreTrainedTokenizer
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
@runtime_checkable
class AsyncEngineClient(Protocol):
"""Protocol class for Clients to AsyncLLMEngine"""
@property
def is_running(self) -> bool:
...
@property
def is_stopped(self) -> bool:
...
@property
def errored(self) -> bool:
...
async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Generates outputs for a request"""
async def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model."""
async def abort(self, request_id: str) -> None:
"""Abort a request.
Args:
request_id: The unique id of the request.
"""
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> PreTrainedTokenizer:
"""Get the appropriate Tokenizer for the request"""
async def is_tracing_enabled(self) -> bool:
pass
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass
async def check_health(self) -> None:
"""Raise if unhealthy"""
......@@ -5,7 +5,8 @@ import re
import signal
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Optional, Set
from multiprocessing import Process
from typing import AsyncIterator, Set
import fastapi
import uvicorn
......@@ -17,8 +18,10 @@ from prometheus_client import make_asgi_app
from starlette.routing import Mount
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
......@@ -31,6 +34,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
......@@ -39,12 +44,12 @@ from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
engine: AsyncLLMEngine
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
......@@ -56,13 +61,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16").embedding_mode
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await engine.do_log_stats()
await async_engine_client.do_log_stats()
if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
......@@ -72,6 +86,52 @@ async def lifespan(app: fastapi.FastAPI):
yield
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
# Backend itself still global for the silly lil' health handler
global async_engine_client
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(args.model)
or args.disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
yield async_engine_client
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
# Start RPCServer in separate process (holds the AsyncLLMEngine).
port = get_open_port(envs.VLLM_RPC_PORT)
rpc_server_process = Process(target=run_rpc_server,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
port))
rpc_server_process.start()
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(port)
await async_engine_client.setup()
try:
yield async_engine_client
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
# Close all open connections to the backend
async_engine_client.close()
# Wait for server process to join
rpc_server_process.join()
router = APIRouter()
......@@ -86,7 +146,7 @@ def mount_metrics(app: fastapi.FastAPI):
@router.get("/health")
async def health() -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
await async_engine_client.check_health()
return Response(status_code=200)
......@@ -215,8 +275,8 @@ def build_app(args):
async def build_server(
async_engine_client: AsyncEngineClient,
args,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs,
) -> uvicorn.Server:
app = build_app(args)
......@@ -226,14 +286,7 @@ async def build_server(
else:
served_model_names = [args.model]
global engine, engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
model_config = await engine.get_model_config()
model_config = await async_engine_client.get_model_config()
if args.disable_log_requests:
request_logger = None
......@@ -246,7 +299,7 @@ async def build_server(
global openai_serving_tokenization
openai_serving_chat = OpenAIServingChat(
engine,
async_engine_client,
model_config,
served_model_names,
args.response_role,
......@@ -257,7 +310,7 @@ async def build_server(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_completion = OpenAIServingCompletion(
engine,
async_engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
......@@ -266,13 +319,13 @@ async def build_server(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
async_engine_client,
model_config,
served_model_names,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
engine,
async_engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
......@@ -304,13 +357,16 @@ async def build_server(
return uvicorn.Server(config)
async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
shutdown_task = None
async with build_async_engine_client(args) as async_engine_client:
server = await build_server(
async_engine_client,
args,
llm_engine,
**uvicorn_kwargs,
)
......@@ -328,8 +384,12 @@ async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
try:
await server_task
except asyncio.CancelledError:
print("Gracefully stopping http server")
await server.shutdown()
logger.info("Gracefully stopping http server")
shutdown_task = server.shutdown()
if shutdown_task:
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
if __name__ == "__main__":
......
......@@ -131,9 +131,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When --max-logprobs is specified, represents single tokens as"
"strings of the form 'token_id:{token_id}' so that tokens that"
help="When --max-logprobs is specified, represents single tokens as "
"strings of the form 'token_id:{token_id}' so that tokens that "
"are not JSON-encodable can be identified.")
parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
parser = AsyncEngineArgs.add_cli_args(parser)
......
from functools import lru_cache
from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
import torch
......@@ -40,6 +40,14 @@ def _get_allowed_token_ids_logits_processor(
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(logit_bias: Dict[str,
float], token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits
def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
......@@ -64,13 +72,8 @@ def get_logits_processors(
raise ValueError("token_id in logit_bias contains "
"out-of-vocab token id")
def logit_bias_logits_processor(token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in clamped_logit_bias.items():
logits[token_id] += bias
return logits
logits_processors.append(logit_bias_logits_processor)
logits_processors.append(
partial(logit_bias_logits_processor, clamped_logit_bias))
if allowed_token_ids is not None:
logits_processors.append(
......
from dataclasses import dataclass
from enum import Enum
from typing import Mapping, Optional, Union
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
VLLM_RPC_SUCCESS_STR = "SUCCESS"
VLLM_RPC_HEALTHY_STR = "HEALTHY"
@dataclass
class RPCGenerateRequest:
inputs: PromptInputs
sampling_params: SamplingParams
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
@dataclass
class RPCAbortRequest:
request_id: str
class RPCUtilityRequest(Enum):
IS_SERVER_READY = 1
GET_MODEL_CONFIG = 2
GET_DECODING_CONFIG = 3
GET_PARALLEL_CONFIG = 4
GET_SCHEDULER_CONFIG = 5
GET_LORA_CONFIG = 6
DO_LOG_STATS = 7
CHECK_HEALTH = 8
IS_TRACING_ENABLED = 9
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
RPCUtilityRequest]
from contextlib import contextmanager
from typing import Any, AsyncIterator, Mapping, Optional
import cloudpickle
import zmq
import zmq.asyncio
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
class AsyncEngineRPCClient:
def __init__(self, port: int):
self.context = zmq.asyncio.Context()
self.path = f"tcp://localhost:{port}"
async def setup(self):
"""Setup the client before it starts sending server requests."""
# Wait until server is ready.
await self.wait_for_server()
# Get the configs.
self.model_config = await self._get_model_config_rpc()
self.decoding_config = await self._get_decoding_config_rpc()
self.tracing_flag = await self._is_tracing_enabled_rpc()
# Create the tokenizer group.
# TODO: refactor OAI server to avoid needing this info.
self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=(await self._get_scheduler_config_rpc()),
parallel_config=(await self._get_parallel_config_rpc()),
enable_lora=bool(await self._get_lora_config_rpc()),
)
def close(self):
"""Destroy the ZeroMQ Context."""
self.context.destroy()
@contextmanager
def socket(self):
# Ensure client sockets are always closed after use
# Connect to RPC socket for Request-Reply pattern,
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket = self.context.socket(zmq.constants.DEALER)
try:
socket.connect(self.path)
yield socket
finally:
socket.close()
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
expected_type: Any,
error_message: str) -> Any:
"""Send an RPC request that is expecting data back."""
with self.socket() as socket:
# Ping RPCServer with a request.
await socket.send(cloudpickle.dumps(request))
# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
if not isinstance(data, expected_type):
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
pass
else:
raise ValueError(error_message)
return data
async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
error_message: str):
"""Send one-way RPC request to trigger an action."""
with self.socket() as socket:
# Ping RPC Server with request.
await socket.send(cloudpickle.dumps(request))
# Await acknowledgement from RPCServer.
response = cloudpickle.loads(await socket.recv())
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
raise ValueError(error_message)
return response
async def get_tokenizer(self, lora_request: LoRARequest):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_decoding_config(self) -> DecodingConfig:
return self.decoding_config
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def is_tracing_enabled(self) -> bool:
return self.tracing_flag
async def wait_for_server(self):
"""Wait for the RPCServer to start up."""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server.")
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_MODEL_CONFIG,
expected_type=ModelConfig,
error_message="Could not get ModelConfig from RPC Server")
async def _get_decoding_config_rpc(self) -> DecodingConfig:
"""Get DecodingConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_DECODING_CONFIG,
expected_type=DecodingConfig,
error_message="Could not get DecodingConfig from RPC Server")
async def _get_parallel_config_rpc(self) -> ParallelConfig:
"""Get ParallelConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_PARALLEL_CONFIG,
expected_type=ParallelConfig,
error_message="Could not get ParallelConfig from RPC Server")
async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
"""Get SchedulerConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server")
async def _get_lora_config_rpc(self):
"""Get LoRAConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_LORA_CONFIG,
expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server")
async def _is_tracing_enabled_rpc(self) -> ParallelConfig:
"""Get is_tracing_enabled flag from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.IS_TRACING_ENABLED,
expected_type=bool,
error_message="Could not get is_tracing_enabled flag from RPC "
"Server")
async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
error_message=f"RPCAbortRequest {request_id} failed")
async def do_log_stats(self):
"""Send a DO_LOG_STATS signal to the RPC Server"""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")
async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
with self.socket() as socket:
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
])
# Stream back the results from the RPC Server.
while True:
message = await socket.recv()
request_output = cloudpickle.loads(message)
if isinstance(request_output, Exception):
raise request_output
if request_output.finished:
break
yield request_output
yield request_output
async def check_health(self) -> None:
"""Raise if unhealthy"""
with self.socket() as socket:
# Ping RPCServer with CHECK_HEALTH request.
await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
)
# Await the reply from the server.
# TODO: do we need an internal timeout here?
# Or do we expect the external probe to timeout and let this chill?
health_message = cloudpickle.loads(await socket.recv())
if isinstance(health_message, Exception):
raise health_message
if health_message != VLLM_RPC_HEALTHY_STR:
raise ValueError("Expected healthy response from backend but got "
"f{health_message}")
async def encode(self, *args,
**kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")
import asyncio
import signal
from typing import Any, Coroutine
import cloudpickle
import zmq
import zmq.asyncio
from typing_extensions import Never
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context)
# Initialize context.
self.context = zmq.asyncio.Context()
# Init socket for readiness state.
self.socket = self.context.socket(zmq.constants.ROUTER)
self.socket.bind(f"tcp://localhost:{port}")
def cleanup(self):
"""Cleanup all resources."""
self.socket.close()
self.context.destroy()
async def get_model_config(self, identity):
"""Send the ModelConfig"""
model_config = await self.engine.get_model_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(model_config)])
async def get_decoding_config(self, identity):
"""Send the DecodingConfig"""
decoding_config = await self.engine.get_decoding_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(decoding_config)])
async def get_lora_config(self, identity):
lora_config = await self.engine.get_lora_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(lora_config)])
async def get_scheduler_config(self, identity):
"""Send the SchedulerConfig"""
parallel_config = await self.engine.get_scheduler_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def get_parallel_config(self, identity):
"""Send the ParallelConfig"""
parallel_config = await self.engine.get_parallel_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)
# Send confirmation to the client.
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
results_generator = self.engine.generate(
generate_request.inputs,
sampling_params=generate_request.sampling_params,
request_id=generate_request.request_id,
lora_request=generate_request.lora_request,
trace_headers=generate_request.trace_headers,
prompt_adapter_request=generate_request.prompt_adapter_request)
async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
except Exception as e:
### Notify client of all failures
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
request = cloudpickle.loads(message)
if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
elif isinstance(request, RPCAbortRequest):
return self.abort(identity, request)
elif isinstance(request, RPCUtilityRequest):
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
return self.get_model_config(identity)
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
return self.get_parallel_config(identity)
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
return self.get_decoding_config(identity)
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
return self.get_scheduler_config(identity)
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
return self.get_lora_config(identity)
elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity)
elif request == RPCUtilityRequest.IS_SERVER_READY:
return self.is_server_ready(identity)
elif request == RPCUtilityRequest.CHECK_HEALTH:
return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
else:
raise ValueError(f"Unknown RPCRequest type: {request}")
async def run_server_loop(self):
"""Inner RPC Server Loop"""
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart()
# Process the request async.
task = asyncio.create_task(
self._make_handler_coro(identity, message))
# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution as running tasks
# can be GC'ed. Below is a common "fire-and-forget" tasks
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
running_tasks.add(task)
task.add_done_callback(running_tasks.discard)
async def run_server(server: AsyncEngineRPCServer):
# Put the server task into the asyncio loop.
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.run_server_loop())
# Interruption handling.
def signal_handler() -> None:
# Kill the server on interrupt / terminate
server_task.cancel()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
except asyncio.CancelledError:
logger.info("vLLM ZMQ RPC Server was interrupted.")
finally:
# Clean up all resources.
server.cleanup()
def run_rpc_server(async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int):
server = AsyncEngineRPCServer(async_engine_args, usage_context, port)
asyncio.run(run_server(server))
......@@ -8,7 +8,7 @@ from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
......@@ -39,7 +39,7 @@ class OpenAIServingChat(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
......@@ -50,7 +50,7 @@ class OpenAIServingChat(OpenAIServing):
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
......@@ -89,7 +89,8 @@ class OpenAIServingChat(OpenAIServing):
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
......@@ -161,7 +162,8 @@ class OpenAIServingChat(OpenAIServing):
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()
is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
......@@ -169,7 +171,7 @@ class OpenAIServingChat(OpenAIServing):
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
result_generator = self.engine.generate(
result_generator = self.async_engine_client.generate(
engine_inputs,
sampling_params,
request_id,
......@@ -441,7 +443,7 @@ class OpenAIServingChat(OpenAIServing):
async for res in result_generator:
if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
await self.async_engine_client.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None
......
......@@ -8,7 +8,7 @@ from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
......@@ -42,7 +42,7 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
......@@ -51,7 +51,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
......@@ -91,7 +91,8 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
......@@ -119,7 +120,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = await self.engine.is_tracing_enabled()
is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled:
trace_headers = extract_trace_headers(raw_request.headers)
......@@ -127,7 +129,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.headers):
log_tracing_disabled_warning()
generator = self.engine.generate(
generator = self.async_engine_client.generate(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params,
request_id_item,
......@@ -168,7 +170,7 @@ class OpenAIServingCompletion(OpenAIServing):
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
......@@ -230,7 +232,8 @@ class OpenAIServingCompletion(OpenAIServing):
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await self.engine.abort(f"{request_id}-{prompt_idx}")
await self.async_engine_client.abort(
f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
for output in res.outputs:
......
......@@ -6,7 +6,7 @@ import numpy as np
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
......@@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
request_logger: Optional[RequestLogger],
):
super().__init__(engine=engine,
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None,
......@@ -99,7 +99,8 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
pooling_params = request.to_pooling_params()
......@@ -124,7 +125,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported "
"for embedding models")
generator = self.engine.encode(
generator = self.async_engine_client.encode(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
pooling_params,
request_id_item,
......@@ -146,7 +147,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
......
......@@ -8,7 +8,7 @@ from pydantic import Field
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
......@@ -61,7 +61,7 @@ class OpenAIServing:
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
......@@ -72,7 +72,7 @@ class OpenAIServing:
):
super().__init__()
self.engine = engine
self.async_engine_client = async_engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
......@@ -155,7 +155,7 @@ class OpenAIServing:
async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.engine.get_decoding_config()
decoding_config = await self.async_engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(
......
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
......@@ -24,7 +24,7 @@ class OpenAIServingTokenization(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
......@@ -32,7 +32,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine,
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
......@@ -57,7 +57,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
......@@ -113,7 +113,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id,
request.tokens,
......
......@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None
VLLM_RPC_PORT: int = 5570
VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_INSTANCE_ID: Optional[str] = None
......@@ -140,6 +141,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None,
# used when the frontend api server is running in multi-processing mode,
# to communicate with the backend engine process over ZMQ.
'VLLM_RPC_PORT':
lambda: int(os.getenv('VLLM_PORT', '5570')),
# If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE":
......
......@@ -21,6 +21,8 @@ from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
import torch
from lark import Lark
from outlines import grammars
from outlines.caching import cache
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
from outlines.fsm.json_schema import build_regex_from_schema
......@@ -44,6 +46,23 @@ class BaseLogitsProcessor:
last_seq_id = hash(tuple(input_ids[:-1]))
self._fsm_state[seq_id] = self._guide.get_next_state(
state=self._fsm_state[last_seq_id], token_id=last_token)
else:
# Note: this is a hack.
# Lark pickling does not work properly (silent failure),
# which breaks the RPC (which uses python pickleing).
# We need to find a better solution.
# On the first time this is called, we simply re-create
# the Lark object.
if isinstance(self._guide, CFGGuide):
self._guide.parser = Lark(
self._guide.cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
import_paths=[grammars.GRAMMAR_PATH],
)
instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id])
......
......@@ -60,7 +60,7 @@ def get_span_exporter(endpoint):
OTLPSpanExporter)
elif protocol == "http/protobuf":
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter)
OTLPSpanExporter) # type: ignore
else:
raise ValueError(
f"Unsupported OTLP protocol '{protocol}' is configured")
......
from typing import Optional, Type
from vllm.config import TokenizerPoolConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
TokenizerPoolConfig)
from vllm.executor.ray_utils import ray
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
......@@ -13,6 +14,22 @@ else:
RayTokenizerGroupPool = None # type: ignore
def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
enable_lora: bool):
init_kwargs = dict(tokenizer_id=model_config.tokenizer,
enable_lora=enable_lora,
max_num_seqs=scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision)
return get_tokenizer_group(parallel_config.tokenizer_pool_config,
**init_kwargs)
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> BaseTokenizerGroup:
tokenizer_cls: Type[BaseTokenizerGroup]
......
......@@ -290,6 +290,10 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return _async_wrapper
class ProducerFinished:
pass
def merge_async_iterators(
*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
"""Merge multiple asynchronous iterators into a single iterator.
......@@ -298,9 +302,10 @@ def merge_async_iterators(
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
Exception]] = asyncio.Queue()
finished = [False] * len(iterators)
producers = len(iterators)
async def producer(i: int, iterator: AsyncIterator[T]):
try:
......@@ -308,7 +313,8 @@ def merge_async_iterators(
await queue.put((i, item))
except Exception as e:
await queue.put(e)
finished[i] = True
# Signal to the consumer that we've finished
await queue.put(ProducerFinished())
_tasks = [
asyncio.create_task(producer(i, iterator))
......@@ -316,9 +322,17 @@ def merge_async_iterators(
]
async def consumer():
remaining = producers
try:
while not all(finished) or not queue.empty():
while remaining or not queue.empty():
# we think there is a race condition here
item = await queue.get()
if isinstance(item, ProducerFinished):
# Signal that a producer finished- not a real item
remaining -= 1
continue
if isinstance(item, Exception):
raise item
yield item
......@@ -374,7 +388,9 @@ def get_distributed_init_method(ip: str, port: int) -> str:
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
def get_open_port() -> int:
def get_open_port(port: Optional[int] = None) -> int:
if port is None:
# Default behavior here is to return a port for multi-gpu communication
port = envs.VLLM_PORT
if port is not None:
while True:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment