Commit 539aa992 authored by zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
...@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline ...@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
""" """
import contextlib import contextlib
import pickle import pickle
import weakref
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -34,6 +35,8 @@ from torch.distributed import Backend, ProcessGroup ...@@ -34,6 +35,8 @@ from torch.distributed import Backend, ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import supports_custom_op
@dataclass @dataclass
...@@ -69,6 +72,59 @@ def _split_tensor_dict( ...@@ -69,6 +72,59 @@ def _split_tensor_dict(
return metadata_list, tensor_list return metadata_list, tensor_list
# Next index to hand out for each base group name.
_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Return ``name`` suffixed with a per-name running index.

    Each call with the same ``name`` yields the next index, starting at 0.

    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    index = _group_name_counter.get(name, 0)
    # Record the index the next caller with this name will receive.
    _group_name_counter[name] = index + 1
    return f"{name}:{index}"
# Maps a group's unique name to a zero-argument callable (a weak reference)
# that returns the group, or None once the group has been collected.
_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    """Record ``group`` in ``_groups`` under its ``unique_name``.

    Only a weak reference is stored, so the registry does not keep the
    group alive.
    """
    # looks like Python 3.8 does not understand `ReferenceType`
    group_ref = weakref.ref(group)
    _groups[group.unique_name] = group_ref  # type: ignore
if supports_custom_op():
    # Both ops take the group's unique name (a string) and resolve the
    # coordinator through the `_groups` registry of weak references.

    @torch.library.custom_op("vllm::inplace_all_reduce",
                             mutates_args=["tensor"])
    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
        """All-reduce ``tensor`` in place via the group named ``group_name``.

        Raises:
            AssertionError: if ``group_name`` was never registered.
            ValueError: if the registered group no longer exists.
        """
        assert group_name in _groups, f"Group {group_name} is not found."
        coordinator = _groups[group_name]()
        # A dead weak reference returns None when called.
        if coordinator is None:
            raise ValueError(f"Group {group_name} is destroyed.")
        coordinator._all_reduce(tensor)

    @inplace_all_reduce.register_fake
    def _(tensor: torch.Tensor, group_name: str) -> None:
        # Fake implementation: the in-place op returns nothing.
        return

    @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
    def outplace_all_reduce(tensor: torch.Tensor,
                            group_name: str) -> torch.Tensor:
        """All-reduce ``tensor`` via the group named ``group_name`` and
        return the result of the group's ``_all_reduce``.

        Raises:
            AssertionError: if ``group_name`` was never registered.
            ValueError: if the registered group no longer exists.
        """
        assert group_name in _groups, f"Group {group_name} is not found."
        coordinator = _groups[group_name]()
        # A dead weak reference returns None when called.
        if coordinator is None:
            raise ValueError(f"Group {group_name} is destroyed.")
        return coordinator._all_reduce(tensor)

    @outplace_all_reduce.register_fake
    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
        # Fake implementation: an uninitialized tensor shaped like the input.
        return torch.empty_like(tensor)
class GroupCoordinator: class GroupCoordinator:
""" """
PyTorch ProcessGroup wrapper for a group of processes. PyTorch ProcessGroup wrapper for a group of processes.
...@@ -111,7 +167,11 @@ class GroupCoordinator: ...@@ -111,7 +167,11 @@ class GroupCoordinator:
use_custom_allreduce: bool, use_custom_allreduce: bool,
use_tpu_communicator: bool, use_tpu_communicator: bool,
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
): ):
group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name)
_register_group(self)
self.rank = torch.distributed.get_rank() self.rank = torch.distributed.get_rank()
self.local_rank = local_rank self.local_rank = local_rank
...@@ -134,7 +194,7 @@ class GroupCoordinator: ...@@ -134,7 +194,7 @@ class GroupCoordinator:
assert self.cpu_group is not None assert self.cpu_group is not None
assert self.device_group is not None assert self.device_group is not None
if torch.cuda.is_available(): if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}") self.device = torch.device(f"cuda:{local_rank}")
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
...@@ -149,28 +209,24 @@ class GroupCoordinator: ...@@ -149,28 +209,24 @@ class GroupCoordinator:
from vllm.distributed.device_communicators.pynccl import ( from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator) PyNcclCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1: if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator( self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
) )
else:
self.pynccl_comm = None
self.ca_comm: Optional[CustomAllreduce] self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1: if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation. # Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce( self.ca_comm = CustomAllreduce(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
) )
else:
self.ca_comm = None
from vllm.distributed.device_communicators.tpu_communicator import ( from vllm.distributed.device_communicators.tpu_communicator import (
TpuCommunicator) TpuCommunicator)
self.tpu_communicator: Optional[TpuCommunicator] self.tpu_communicator: Optional[TpuCommunicator] = None
if use_tpu_communicator and self.world_size > 1: if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group) self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
...@@ -264,16 +320,49 @@ class GroupCoordinator: ...@@ -264,16 +320,49 @@ class GroupCoordinator:
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
""" """
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
if not supports_custom_op():
return self._all_reduce(input_)
if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self._all_reduce(input_)
if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
return torch.ops.vllm.outplace_all_reduce(
input_, group_name=self.unique_name)
else:
torch.ops.vllm.inplace_all_reduce(input_,
group_name=self.unique_name)
return input_
def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place. NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return Always assume this function modifies its input, but use the return
value as the output. value as the output.
""" """
ca_comm = self.ca_comm ca_comm = self.ca_comm
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
# For TPUs, use TPU communicator. # For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled: if tpu_comm is not None and not tpu_comm.disabled:
...@@ -758,6 +847,7 @@ def init_world_group(ranks: List[int], local_rank: int, ...@@ -758,6 +847,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_pynccl=False, use_pynccl=False,
use_custom_allreduce=False, use_custom_allreduce=False,
use_tpu_communicator=False, use_tpu_communicator=False,
group_name="world",
) )
...@@ -767,6 +857,7 @@ def init_model_parallel_group( ...@@ -767,6 +857,7 @@ def init_model_parallel_group(
backend: str, backend: str,
use_custom_allreduce: Optional[bool] = None, use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
) -> GroupCoordinator: ) -> GroupCoordinator:
if use_custom_allreduce is None: if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
...@@ -778,6 +869,7 @@ def init_model_parallel_group( ...@@ -778,6 +869,7 @@ def init_model_parallel_group(
use_custom_allreduce=use_custom_allreduce, use_custom_allreduce=use_custom_allreduce,
use_tpu_communicator=True, use_tpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster, use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
) )
...@@ -931,7 +1023,8 @@ def initialize_model_parallel( ...@@ -931,7 +1023,8 @@ def initialize_model_parallel(
_TP = init_model_parallel_group(group_ranks, _TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank, get_world_group().local_rank,
backend, backend,
use_message_queue_broadcaster=True) use_message_queue_broadcaster=True,
group_name="tp")
# Build the pipeline model-parallel groups. # Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size // num_pipeline_model_parallel_groups: int = (world_size //
...@@ -947,7 +1040,8 @@ def initialize_model_parallel( ...@@ -947,7 +1040,8 @@ def initialize_model_parallel(
_PP = init_model_parallel_group(group_ranks, _PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank, get_world_group().local_rank,
backend, backend,
use_custom_allreduce=False) use_custom_allreduce=False,
group_name="pp")
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
......
...@@ -44,22 +44,36 @@ def nullable_str(val: str): ...@@ -44,22 +44,36 @@ def nullable_str(val: str):
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dictionary.

    Keys and values are lower-cased and stripped of surrounding whitespace
    before use; values must parse as integers.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary with parsed values, or None when ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: if an item is not of the form KEY=VALUE,
            a value is not an integer, or a key is given two different
            values.
    """
    if len(val) == 0:
        return None

    parsed: Dict[str, int] = {}
    for entry in val.split(","):
        pieces = [piece.lower().strip() for piece in entry.split("=")]
        if len(pieces) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = pieces

        try:
            number = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is accepted only when the value is unchanged.
        if parsed.get(key, number) != number:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        parsed[key] = number

    return parsed
...@@ -131,6 +145,7 @@ class EngineArgs: ...@@ -131,6 +145,7 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None max_cpu_loras: Optional[int] = None
device: str = 'auto' device: str = 'auto'
num_scheduler_steps: int = 1 num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = False
ray_workers_use_nsight: bool = False ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0 num_lookahead_slots: int = 0
...@@ -161,6 +176,7 @@ class EngineArgs: ...@@ -161,6 +176,7 @@ class EngineArgs:
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None override_neuron_config: Optional[Dict[str, Any]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
...@@ -458,7 +474,10 @@ class EngineArgs: ...@@ -458,7 +474,10 @@ class EngineArgs:
default=EngineArgs.max_seq_len_to_capture, default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA ' help='Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length ' 'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.') 'larger than this, we fall back to eager mode. '
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.')
parser.add_argument('--disable-custom-all-reduce', parser.add_argument('--disable-custom-all-reduce',
action='store_true', action='store_true',
default=EngineArgs.disable_custom_all_reduce, default=EngineArgs.disable_custom_all_reduce,
...@@ -496,6 +515,12 @@ class EngineArgs: ...@@ -496,6 +515,12 @@ class EngineArgs:
'e.g.: `image=16,video=2` allows a maximum of 16 ' 'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to 1 for ' 'images and 2 videos per prompt. Defaults to 1 for '
'each modality.')) 'each modality.'))
parser.add_argument(
'--mm-processor-kwargs',
default=None,
type=json.loads,
help=('Overrides for the multimodal input mapping/processing,'
'e.g., image processor. For example: {"num_crops": 4}.'))
# LoRA related configs # LoRA related configs
parser.add_argument('--enable-lora', parser.add_argument('--enable-lora',
...@@ -571,6 +596,10 @@ class EngineArgs: ...@@ -571,6 +596,10 @@ class EngineArgs:
help=('Maximum number of forward steps per ' help=('Maximum number of forward steps per '
'scheduler call.')) 'scheduler call.'))
parser.add_argument(
'--multi-step-stream-outputs',
action='store_true',
help='If True, then multi-step will stream outputs for every step')
parser.add_argument( parser.add_argument(
'--scheduler-delay-factor', '--scheduler-delay-factor',
type=float, type=float,
...@@ -805,6 +834,7 @@ class EngineArgs: ...@@ -805,6 +834,7 @@ class EngineArgs:
use_async_output_proc=not self.disable_async_output_proc, use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config, override_neuron_config=self.override_neuron_config,
config_format=self.config_format, config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
) )
def create_load_config(self) -> LoadConfig: def create_load_config(self) -> LoadConfig:
...@@ -974,6 +1004,7 @@ class EngineArgs: ...@@ -974,6 +1004,7 @@ class EngineArgs:
is_multimodal_model=model_config.is_multimodal_model, is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode, preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps, num_scheduler_steps=self.num_scheduler_steps,
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray), and parallel_config.use_ray),
) )
......
import asyncio import asyncio
import time import time
import weakref
from functools import partial from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union) Mapping, Optional, Set, Tuple, Type, Union)
from weakref import ReferenceType
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
...@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams ...@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import weak_bind
logger = init_logger(__name__) logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
...@@ -450,9 +453,6 @@ class AsyncLLMEngine: ...@@ -450,9 +453,6 @@ class AsyncLLMEngine:
method yields the outputs from the :class:`LLMEngine` to the caller. method yields the outputs from the :class:`LLMEngine` to the caller.
Args: Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
log_requests: Whether to log the requests. log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call. will be automatically started in the generate call.
...@@ -463,23 +463,22 @@ class AsyncLLMEngine: ...@@ -463,23 +463,22 @@ class AsyncLLMEngine:
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
def __init__(self, def __init__(self,
worker_use_ray: bool,
*args, *args,
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
**kwargs) -> None: **kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.log_requests = log_requests self.log_requests = log_requests
self.engine = self._engine_class(*args, **kwargs) self.engine = self._engine_class(*args, **kwargs)
# This ensures quick processing of request outputs # This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed, # so the append to asyncio queues is not delayed,
# especially for multi-step. # especially for multi-step.
# self.use_process_request_outputs_callback = (
self.use_process_request_outputs_callback = True self.engine.model_config.use_async_output_proc)
if self.use_process_request_outputs_callback: if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \ self.engine.process_request_outputs_callback = \
self.process_request_outputs weak_bind(self.process_request_outputs)
self.background_loop: Optional[asyncio.Future] = None self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded # We need to keep a reference to unshielded
...@@ -492,6 +491,11 @@ class AsyncLLMEngine: ...@@ -492,6 +491,11 @@ class AsyncLLMEngine:
# Lazy initialized fields # Lazy initialized fields
self._request_tracker: RequestTracker self._request_tracker: RequestTracker
def __del__(self):
if rt := getattr(self, "request_tracker", None):
# Wake up engine loop so that it will exit cleanly
rt.new_requests_event.set()
@classmethod @classmethod
def _get_executor_cls( def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
...@@ -502,15 +506,12 @@ class AsyncLLMEngine: ...@@ -502,15 +506,12 @@ class AsyncLLMEngine:
raise TypeError( raise TypeError(
"distributed_executor_backend must be a subclass of " "distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.") f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron": elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu": elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray": if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync executor_class = RayTPUExecutorAsync
else: else:
...@@ -531,11 +532,9 @@ class AsyncLLMEngine: ...@@ -531,11 +532,9 @@ class AsyncLLMEngine:
from vllm.executor.xpu_executor import XPUExecutorAsync from vllm.executor.xpu_executor import XPUExecutorAsync
executor_class = XPUExecutorAsync executor_class = XPUExecutorAsync
elif distributed_executor_backend == "ray": elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp": elif distributed_executor_backend == "mp":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.multiproc_xpu_executor import ( from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync) MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync executor_class = MultiprocessingXPUExecutorAsync
...@@ -543,7 +542,6 @@ class AsyncLLMEngine: ...@@ -543,7 +542,6 @@ class AsyncLLMEngine:
raise RuntimeError( raise RuntimeError(
"Not supported distributed execution model on XPU device.") "Not supported distributed execution model on XPU device.")
elif distributed_executor_backend == "ray": elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp": elif distributed_executor_backend == "mp":
...@@ -559,19 +557,23 @@ class AsyncLLMEngine: ...@@ -559,19 +557,23 @@ class AsyncLLMEngine:
def from_engine_args( def from_engine_args(
cls, cls,
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None,
start_engine_loop: bool = True, start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine": ) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments.""" """Creates an async LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_config = engine_args.create_engine_config() if engine_config is None:
engine_config = engine_args.create_engine_config()
executor_class = cls._get_executor_cls(engine_config) executor_class = cls._get_executor_cls(engine_config)
if executor_class.uses_ray:
initialize_ray_cluster(engine_config.parallel_config)
# Create the async LLM engine. # Create the async LLM engine.
engine = cls( engine = cls(
executor_class.uses_ray,
**engine_config.to_dict(), **engine_config.to_dict(),
executor_class=executor_class, executor_class=executor_class,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
...@@ -599,9 +601,12 @@ class AsyncLLMEngine: ...@@ -599,9 +601,12 @@ class AsyncLLMEngine:
return self._errored_with is not None return self._errored_with is not None
@property @property
def limit_concurrency(self) -> Optional[int]: def dead_error(self) -> BaseException:
"""Maximum number of concurrently running requests.""" return AsyncEngineDeadError(
return None "Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
def set_errored(self, exc: Exception) -> None: def set_errored(self, exc: Exception) -> None:
self._errored_with = exc self._errored_with = exc
...@@ -628,7 +633,7 @@ class AsyncLLMEngine: ...@@ -628,7 +633,7 @@ class AsyncLLMEngine:
self._request_tracker = RequestTracker() self._request_tracker = RequestTracker()
self._background_loop_unshielded = asyncio.get_event_loop( self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop()) ).create_task(self.run_engine_loop(weakref.ref(self)))
self._background_loop_unshielded.add_done_callback( self._background_loop_unshielded.add_done_callback(
partial(_log_task_completion, error_callback=self._error_callback)) partial(_log_task_completion, error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded) self.background_loop = asyncio.shield(self._background_loop_unshielded)
...@@ -698,9 +703,16 @@ class AsyncLLMEngine: ...@@ -698,9 +703,16 @@ class AsyncLLMEngine:
async def _engine_abort(self, request_ids: Iterable[str]): async def _engine_abort(self, request_ids: Iterable[str]):
self.engine.abort_request(request_ids) self.engine.abort_request(request_ids)
async def run_engine_loop(self): @staticmethod
async def run_engine_loop(engine_ref: ReferenceType):
"""We use a weakref to the engine so that the running loop
doesn't prevent the engine being garbage collected."""
engine: Optional["AsyncLLMEngine"] = engine_ref()
if not engine:
return
pipeline_parallel_size = \ pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size engine.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size has_requests_in_progress = [False] * pipeline_parallel_size
while True: while True:
if not any(has_requests_in_progress): if not any(has_requests_in_progress):
...@@ -711,11 +723,21 @@ class AsyncLLMEngine: ...@@ -711,11 +723,21 @@ class AsyncLLMEngine:
# timeout, and unblocks the RPC thread in the workers so that # timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages, # they can process any other queued control plane messages,
# such as add/remove lora adapters. # such as add/remove lora adapters.
await self.engine.stop_remote_worker_execution_loop_async() await engine.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests() request_tracker = engine._request_tracker
# Allow engine to be garbage collected while
# waiting for new requests
del engine
await asyncio.sleep(0)
if engine_ref() is None:
return
await request_tracker.wait_for_new_requests()
engine = engine_ref()
if not engine:
return
logger.debug("Got new requests!") logger.debug("Got new requests!")
requests_in_progress = [ requests_in_progress = [
asyncio.create_task(self.engine_step(ve)) asyncio.create_task(engine.engine_step(ve))
for ve in range(pipeline_parallel_size) for ve in range(pipeline_parallel_size)
] ]
has_requests_in_progress = [True] * pipeline_parallel_size has_requests_in_progress = [True] * pipeline_parallel_size
...@@ -733,19 +755,20 @@ class AsyncLLMEngine: ...@@ -733,19 +755,20 @@ class AsyncLLMEngine:
result = task.result() result = task.result()
virtual_engine = requests_in_progress.index(task) virtual_engine = requests_in_progress.index(task)
has_unfinished_requests = ( has_unfinished_requests = (
self.engine.has_unfinished_requests_for_virtual_engine( engine.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine)) virtual_engine))
if result or has_unfinished_requests: if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = ( requests_in_progress[virtual_engine] = (
asyncio.create_task( asyncio.create_task(
self.engine_step(virtual_engine))) engine.engine_step(virtual_engine)))
has_requests_in_progress[virtual_engine] = True has_requests_in_progress[virtual_engine] = True
else: else:
has_requests_in_progress[virtual_engine] = False has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc: except asyncio.TimeoutError as exc:
logger.error( logger.error(
"Engine iteration timed out. This should never happen!") "Engine iteration timed out. This should never happen!")
self.set_errored(exc) engine.set_errored(exc)
raise raise
await asyncio.sleep(0) await asyncio.sleep(0)
...@@ -806,7 +829,7 @@ class AsyncLLMEngine: ...@@ -806,7 +829,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use prompt_adapter_request: Prompt Adapter request to use
for generation, if any. for generation, if any.
Yields: Yields:
...@@ -1022,7 +1045,7 @@ class AsyncLLMEngine: ...@@ -1022,7 +1045,7 @@ class AsyncLLMEngine:
async def start_profile(self) -> None: async def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing # using type instead of isinstance to check to avoid capturing
# inherited classes # inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync: if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.start_profile() self.engine.model_executor.start_profile()
else: else:
self.engine.model_executor._run_workers("start_profile") self.engine.model_executor._run_workers("start_profile")
...@@ -1030,7 +1053,7 @@ class AsyncLLMEngine: ...@@ -1030,7 +1053,7 @@ class AsyncLLMEngine:
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing # using type instead of isinstance to check to avoid capturing
# inherited classes # inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync: if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.stop_profile() self.engine.model_executor.stop_profile()
else: else:
self.engine.model_executor._run_workers("stop_profile") self.engine.model_executor._run_workers("stop_profile")
import functools
import time import time
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional) Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
...@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import ( ...@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs) BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter, Device from vllm.utils import Counter, Device, weak_bind
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -95,7 +95,7 @@ class OutputData(NamedTuple): ...@@ -95,7 +95,7 @@ class OutputData(NamedTuple):
class SchedulerContext: class SchedulerContext:
def __init__(self): def __init__(self, multi_step_stream_outputs: bool = False):
self.output_queue: Deque[OutputData] = deque() self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput, self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = [] EmbeddingRequestOutput]] = []
...@@ -103,6 +103,8 @@ class SchedulerContext: ...@@ -103,6 +103,8 @@ class SchedulerContext:
List[SequenceGroupMetadata]] = None List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None
self.multi_step_stream_outputs: bool = multi_step_stream_outputs
def append_output(self, outputs: List[SamplerOutput], def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool, scheduler_outputs: SchedulerOutputs, is_async: bool,
...@@ -144,7 +146,7 @@ class LLMEngine: ...@@ -144,7 +146,7 @@ class LLMEngine:
decoding. decoding.
executor_class: The model executor class for managing distributed executor_class: The model executor class for managing distributed
execution. execution.
prompt_adapter_config (Optional): The configuration related to serving prompt_adapter_config (Optional): The configuration related to serving
prompt adapters. prompt adapters.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection. usage_context: Specified entry point, used for usage info collection.
...@@ -219,6 +221,7 @@ class LLMEngine: ...@@ -219,6 +221,7 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None: ) -> None:
logger.info( logger.info(
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
...@@ -234,8 +237,9 @@ class LLMEngine: ...@@ -234,8 +237,9 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, " "quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, " "decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, " "num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"use_async_output_proc=%s)", "enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)",
VLLM_VERSION, VLLM_VERSION,
model_config.model, model_config.model,
speculative_config, speculative_config,
...@@ -266,8 +270,11 @@ class LLMEngine: ...@@ -266,8 +270,11 @@ class LLMEngine:
model_config.served_model_name, model_config.served_model_name,
scheduler_config.use_v2_block_manager, scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps, scheduler_config.num_scheduler_steps,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching, cache_config.enable_prefix_caching,
model_config.use_async_output_proc, model_config.use_async_output_proc,
use_cached_outputs,
model_config.mm_processor_kwargs,
) )
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
...@@ -286,6 +293,7 @@ class LLMEngine: ...@@ -286,6 +293,7 @@ class LLMEngine:
self.observability_config = observability_config or ObservabilityConfig( self.observability_config = observability_config or ObservabilityConfig(
) )
self.log_stats = log_stats self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer() self.tokenizer = self._init_tokenizer()
...@@ -327,136 +335,134 @@ class LLMEngine: ...@@ -327,136 +335,134 @@ class LLMEngine:
observability_config=self.observability_config, observability_config=self.observability_config,
) )
init_success = False if not self.model_config.embedding_mode:
try: self._initialize_kv_caches()
if not self.model_config.embedding_mode:
self._initialize_kv_caches() # If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
# If usage stat is enabled, collect relevant info. from vllm.model_executor.model_loader import (
if is_usage_stats_enabled(): get_architecture_class_name)
from vllm.model_executor.model_loader import ( usage_message.report_usage(
get_architecture_class_name) get_architecture_class_name(model_config),
usage_message.report_usage( usage_context,
get_architecture_class_name(model_config), extra_kvs={
usage_context, # Common configuration
extra_kvs={ "dtype":
# Common configuration str(model_config.dtype),
"dtype": "tensor_parallel_size":
str(model_config.dtype), parallel_config.tensor_parallel_size,
"tensor_parallel_size": "block_size":
parallel_config.tensor_parallel_size, cache_config.block_size,
"block_size": "gpu_memory_utilization":
cache_config.block_size, cache_config.gpu_memory_utilization,
"gpu_memory_utilization":
cache_config.gpu_memory_utilization, # Quantization
"quantization":
# Quantization model_config.quantization,
"quantization": "kv_cache_dtype":
model_config.quantization, str(cache_config.cache_dtype),
"kv_cache_dtype":
str(cache_config.cache_dtype), # Feature flags
"enable_lora":
# Feature flags bool(lora_config),
"enable_lora": "enable_prompt_adapter":
bool(lora_config), bool(prompt_adapter_config),
"enable_prompt_adapter": "enable_prefix_caching":
bool(prompt_adapter_config), cache_config.enable_prefix_caching,
"enable_prefix_caching": "enforce_eager":
cache_config.enable_prefix_caching, model_config.enforce_eager,
"enforce_eager": "disable_custom_all_reduce":
model_config.enforce_eager, parallel_config.disable_custom_all_reduce,
"disable_custom_all_reduce": })
parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [ if self.tokenizer:
SchedulerContext() # Ping the tokenizer to ensure liveness if it runs in a
for _ in range(self.parallel_config.pipeline_parallel_size) # different process.
] self.tokenizer.ping()
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
if model_config.use_async_output_proc:
process_model_outputs = weak_bind(self._process_model_outputs)
self.async_callbacks = [ self.async_callbacks = [
functools.partial(self._process_model_outputs, partial(process_model_outputs,
ctx=self.scheduler_contexts[v_id]) ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size) for v_id in range(self.parallel_config.pipeline_parallel_size)
] ]
else:
self.async_callbacks = []
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self.process_request_outputs_callback: Optional[Callable] = None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = [
Scheduler(
scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
if model_config.use_async_output_proc else None)
for v_id in range(parallel_config.pipeline_parallel_size)
]
# Currently used by AsyncLLMEngine to ensure quick append # Metric Logging.
# of request outputs to asyncio queues if self.log_stats:
self.process_request_outputs_callback = None if stat_loggers is not None:
self.stat_loggers = stat_loggers
# Create the scheduler. else:
# NOTE: the cache_config here have been updated with the numbers of # Lazy import for prometheus multiprocessing.
# GPU and CPU blocks, which are profiled in the distributed executor. # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
self.scheduler = [ # before prometheus_client is imported.
Scheduler( # See https://prometheus.github.io/client_python/multiprocess/
scheduler_config, cache_config, lora_config, from vllm.engine.metrics import (LoggingStatLogger,
parallel_config.pipeline_parallel_size, PrometheusStatLogger)
self.async_callbacks[v_id]
if model_config.use_async_output_proc else None) self.stat_loggers = {
for v_id in range(parallel_config.pipeline_parallel_size) "logging":
] LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
# Metric Logging. "prometheus":
if self.log_stats: PrometheusStatLogger(
if stat_loggers is not None: local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
self.stat_loggers = stat_loggers labels=dict(model_name=model_config.served_model_name),
else: max_model_len=self.model_config.max_model_len),
# Lazy import for prometheus multiprocessing. }
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable self.stat_loggers["prometheus"].info("cache_config",
# before prometheus_client is imported. self.cache_config)
# See https://prometheus.github.io/client_python/multiprocess/
from vllm.engine.metrics import (LoggingStatLogger, self.tracer = None
PrometheusStatLogger) if self.observability_config.otlp_traces_endpoint:
self.tracer = init_tracer(
self.stat_loggers = { "vllm.llm_engine",
"logging": self.observability_config.otlp_traces_endpoint)
LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC), # Create sequence output processor, e.g. for beam search or
"prometheus": # speculative decoding.
PrometheusStatLogger( self.output_processor = (
local_interval=_LOCAL_LOGGING_INTERVAL_SEC, SequenceGroupOutputProcessor.create_output_processor(
labels=dict(model_name=model_config.served_model_name), self.scheduler_config,
max_model_len=self.model_config.max_model_len), self.detokenizer,
} self.scheduler,
self.stat_loggers["prometheus"].info("cache_config", self.seq_counter,
self.cache_config) get_tokenizer_for_seq,
stop_checker=StopChecker(
self.tracer = None self.scheduler_config.max_model_len,
if self.observability_config.otlp_traces_endpoint:
self.tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq, get_tokenizer_for_seq,
stop_checker=StopChecker( ),
self.scheduler_config.max_model_len, ))
get_tokenizer_for_seq,
),
))
init_success = True
finally:
if not init_success:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self.model_executor.shutdown()
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
...@@ -625,6 +631,7 @@ class LLMEngine: ...@@ -625,6 +631,7 @@ class LLMEngine:
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> None: ) -> None:
self._validate_model_inputs(processed_inputs) self._validate_model_inputs(processed_inputs)
# Create the sequences. # Create the sequences.
...@@ -655,7 +662,8 @@ class LLMEngine: ...@@ -655,7 +662,8 @@ class LLMEngine:
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq) encoder_seq=encoder_seq,
priority=priority)
elif isinstance(params, PoolingParams): elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling( seq_group = self._create_sequence_group_with_pooling(
request_id, request_id,
...@@ -664,7 +672,8 @@ class LLMEngine: ...@@ -664,7 +672,8 @@ class LLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq) encoder_seq=encoder_seq,
priority=priority)
else: else:
raise ValueError( raise ValueError(
"Either SamplingParams or PoolingParams must be provided.") "Either SamplingParams or PoolingParams must be provided.")
...@@ -689,6 +698,7 @@ class LLMEngine: ...@@ -689,6 +698,7 @@ class LLMEngine:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -707,6 +717,8 @@ class LLMEngine: ...@@ -707,6 +717,8 @@ class LLMEngine:
arrival_time: The arrival time of the request. If None, we use arrival_time: The arrival time of the request. If None, we use
the current monotonic time. the current monotonic time.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Details: Details:
- Set arrival_time to the current time if it is None. - Set arrival_time to the current time if it is None.
...@@ -735,6 +747,11 @@ class LLMEngine: ...@@ -735,6 +747,11 @@ class LLMEngine:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
if priority > 0 and not self.scheduler_config.policy == "priority":
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
...@@ -754,6 +771,7 @@ class LLMEngine: ...@@ -754,6 +771,7 @@ class LLMEngine:
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority,
) )
def _create_sequence_group_with_sampling( def _create_sequence_group_with_sampling(
...@@ -766,6 +784,7 @@ class LLMEngine: ...@@ -766,6 +784,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams.""" """Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs max_logprobs = self.get_model_config().max_logprobs
...@@ -792,7 +811,8 @@ class LLMEngine: ...@@ -792,7 +811,8 @@ class LLMEngine:
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq) encoder_seq=encoder_seq,
priority=priority)
return seq_group return seq_group
...@@ -805,6 +825,7 @@ class LLMEngine: ...@@ -805,6 +825,7 @@ class LLMEngine:
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams.""" """Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler # Defensive copy of PoolingParams, which are used by the pooler
...@@ -817,7 +838,8 @@ class LLMEngine: ...@@ -817,7 +838,8 @@ class LLMEngine:
lora_request=lora_request, lora_request=lora_request,
pooling_params=pooling_params, pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq) encoder_seq=encoder_seq,
priority=priority)
return seq_group return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
...@@ -877,8 +899,8 @@ class LLMEngine: ...@@ -877,8 +899,8 @@ class LLMEngine:
""" """
return self.scheduler[virtual_engine].has_unfinished_seqs() return self.scheduler[virtual_engine].has_unfinished_seqs()
@staticmethod
def _process_sequence_group_outputs( def _process_sequence_group_outputs(
self,
seq_group: SequenceGroup, seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput], outputs: List[EmbeddingSequenceGroupOutput],
) -> None: ) -> None:
...@@ -1001,7 +1023,8 @@ class LLMEngine: ...@@ -1001,7 +1023,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output: if request_output:
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
...@@ -1022,8 +1045,8 @@ class LLMEngine: ...@@ -1022,8 +1045,8 @@ class LLMEngine:
for scheduler in self.scheduler: for scheduler in self.scheduler:
scheduler.free_finished_seq_groups() scheduler.free_finished_seq_groups()
# For multi-step, do not create outputs each iteration # For multi-step without streaming, don't create outputs each iteration
if not is_last_step: if not is_last_step and not ctx.multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given) # Immediately process request outputs here (if callback is given)
if (finished_now if (finished_now
and self.process_request_outputs_callback is not None): and self.process_request_outputs_callback is not None):
...@@ -1040,17 +1063,27 @@ class LLMEngine: ...@@ -1040,17 +1063,27 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output: if request_output:
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
# For multi-step with streaming, create outputs each iteration
if not is_last_step and ctx.multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if self.process_request_outputs_callback is not None:
self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
return
for seq_group in scheduler_outputs.ignored_seq_groups: for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params params = seq_group.sampling_params
if params is not None and params.output_kind == ( if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished(): RequestOutputKind.DELTA) and not seq_group.is_finished():
continue continue
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output: if request_output:
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
...@@ -1292,6 +1325,7 @@ class LLMEngine: ...@@ -1292,6 +1325,7 @@ class LLMEngine:
# torch.distributed ops which may otherwise timeout, and unblocks # torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other # the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters. # queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
return ctx.request_outputs return ctx.request_outputs
...@@ -1608,7 +1642,7 @@ class LLMEngine: ...@@ -1608,7 +1642,7 @@ class LLMEngine:
def start_profile(self) -> None: def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing # using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor) # inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor: if type(self.model_executor) == GPUExecutor: # noqa: E721
self.model_executor.start_profile() self.model_executor.start_profile()
else: else:
self.model_executor._run_workers("start_profile") self.model_executor._run_workers("start_profile")
...@@ -1616,7 +1650,7 @@ class LLMEngine: ...@@ -1616,7 +1650,7 @@ class LLMEngine:
def stop_profile(self) -> None: def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing # using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor) # inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor: if type(self.model_executor) == GPUExecutor: # noqa: E721
self.model_executor.stop_profile() self.model_executor.stop_profile()
else: else:
self.model_executor._run_workers("stop_profile") self.model_executor._run_workers("stop_profile")
...@@ -1700,7 +1734,11 @@ class LLMEngine: ...@@ -1700,7 +1734,11 @@ class LLMEngine:
def _validate_model_inputs(self, inputs: Union[LLMInputs, def _validate_model_inputs(self, inputs: Union[LLMInputs,
EncoderDecoderLLMInputs]): EncoderDecoderLLMInputs]):
if self.is_encoder_decoder_model(): if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_ids = inputs.get("prompt_token_ids")
elif self.is_encoder_decoder_model():
prompt_ids = inputs.get("encoder_prompt_token_ids") prompt_ids = inputs.get("encoder_prompt_token_ids")
else: else:
prompt_ids = inputs.get("prompt_token_ids") prompt_ids = inputs.get("prompt_token_ids")
......
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Mapping, Optional, Union from typing import List, Mapping, Optional, Union
from vllm import PoolingParams
from vllm.inputs import PromptInputs from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR = "SUCCESS" VLLM_RPC_SUCCESS_STR = "SUCCESS"
# Minimum value of ZMQ.SOCKET_LIMIT to run mp. IPC_INPUT_EXT = "_input_socket"
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM = 0 class MQEngineDeadError(RuntimeError):
pass
@dataclass @dataclass
class RPCGenerateRequest: class RPCProcessRequest:
inputs: PromptInputs inputs: PromptInputs
sampling_params: SamplingParams params: Union[SamplingParams, PoolingParams]
request_id: str request_id: str
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
@dataclass
class RPCError:
request_id: Optional[str]
is_engine_errored: bool
exception: BaseException
@dataclass @dataclass
class RPCAbortRequest: class RPCAbortRequest:
request_id: str request_id: str
class RPCUtilityRequest(Enum): class RPCStartupRequest(Enum):
IS_SERVER_READY = 1 IS_SERVER_READY = 1
GET_MODEL_CONFIG = 2
GET_DECODING_CONFIG = 3
GET_PARALLEL_CONFIG = 4 @dataclass
GET_SCHEDULER_CONFIG = 5 class RPCStartupResponse:
GET_LORA_CONFIG = 6 tracing_enabled: bool
DO_LOG_STATS = 7
IS_SERVER_HEALTHY = 8
IS_TRACING_ENABLED = 9 class RPCUProfileRequest(Enum):
START_PROFILE = 10 START_PROFILE = 1
STOP_PROFILE = 11 STOP_PROFILE = 2
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUtilityRequest] RPCUProfileRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
def ENGINE_DEAD_ERROR(
        error: Optional[BaseException] = None) -> MQEngineDeadError:
    """Build the error reported when the engine loop is not running.

    Args:
        error: The original exception, if known. When given, its ``repr``
            is appended to the message.

    Returns:
        A new ``MQEngineDeadError``; it is returned, not raised.
    """
    message = ("Engine loop is not running. Inspect the stacktrace to "
               "find the original error")
    if error is not None:
        # Only the variant carrying the original error ends with a period.
        message = f"{message}: {repr(error)}."
    return MQEngineDeadError(message)
import asyncio
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
Union)
import cloudpickle
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm import PoolingParams
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
logger = init_logger(__name__)
class MQClientClosedError(Exception):
    """Raised when the client is used after it has been closed.

    Closing the client also closes its ZMQ context, which normally happens
    on server shutdown. Methods such as abort and do_log_stats may still be
    called afterwards and would then try to open a socket, which causes a
    ZMQError with a huge stack trace.

    Raising this dedicated error instead lets callers suppress it.
    """
class MQLLMEngineClient:
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.
MQLLMEngine and MQLLMEngineClient are intended to run in separate
processes communicating via zeromq ipc sockets.
The entrypoint to MQLLMEngineClient is through the generate()
method. On generate() MQLLMEngineClient does three things:
- Creates an asyncio output queue
- Sends a RPCProcessRequest to the MQLLMEngine via zmq
- Pulls RequestOutputs from its queue and yields them
MQLLMEngineClient runs two background loops:
- output_loop: the output loop pulls List[RequestOutput]
from the MQLLMEngine via zmq (each list is the output
of one engine_step in the LLMEngine). It then parses
the list and pushes individual request_outputs into
the corresponding output_queue such that they can be
consumed by the .generate() method.
- health_loop: the health loop queries the health socket
every N seconds, confirming the engine is healthy
"""
def __init__(self, ipc_path: str, engine_config: EngineConfig):
    """Connect to the MQLLMEngine over IPC and start the output handler.

    Args:
        ipc_path: Base IPC address; every socket address is this value
            with the matching ``IPC_*_EXT`` suffix appended.
        engine_config: Source of the model and decoding configs and of
            the settings used to build the tokenizer group.
    """
    self.context = zmq.asyncio.Context()

    # No engine error has been recorded yet.
    self._errored_with: Optional[BaseException] = None

    # Configs kept on the client.
    self.model_config = engine_config.model_config
    self.decoding_config = engine_config.decoding_config

    # Tokenizer group built from the engine configuration.
    tokenizer_kwargs = dict(
        model_config=self.model_config,
        scheduler_config=engine_config.scheduler_config,
        parallel_config=engine_config.parallel_config,
        enable_lora=bool(engine_config.lora_config),
    )
    self.tokenizer = init_tokenizer_from_configs(**tokenizer_kwargs)

    # PUSH socket carrying requests (e.g. RPCProcessRequest) to the
    # MQLLMEngine.
    self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
    self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")

    # PULL socket receiving the streams of RequestOutput from the
    # MQLLMEngine.
    self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
    self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")

    # PULL socket on which heartbeats arrive.
    self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
    self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

    # Only the address is stored for the data socket; get_data_socket()
    # creates a fresh socket on each use.
    self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

    # One queue per in-flight request, keyed by request id and filled by
    # the output handler task started here.
    self.output_queues: Dict[str, asyncio.Queue] = {}
    self.output_loop = asyncio.create_task(self.run_output_handler_loop())

    # Periodic health check of the LLMEngine; it is started only after
    # the MQLLMEngine is ready, so there is no task yet.
    self.health_loop: Optional[asyncio.Task] = None
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
    """Return True when ``engine_args`` asks for pipeline parallelism.

    Pipeline parallel is not yet supported, so any
    ``pipeline_parallel_size`` above 1 is reported as unsupported.
    """
    pipeline_size = engine_args.pipeline_parallel_size
    return pipeline_size > 1
@contextmanager
def get_data_socket(self) -> Iterator[Socket]:
    """Yield a short-lived DEALER socket connected to the data IPC path.

    The socket is closed with ``linger=0`` when the ``with`` block exits,
    including when connecting or the block body raises.
    """
    data_socket = self.context.socket(zmq.constants.DEALER)
    try:
        data_socket.connect(self.data_ipc_path)
        yield data_socket
    finally:
        data_socket.close(linger=0)
async def run_heartbeat_loop(self, timeout: int):
    """Listen for heartbeats until one is missed, fails, or the task is
    cancelled.

    Args:
        timeout: Forwarded to ``heartbeat_socket.poll`` as its timeout
            (presumably milliseconds, per zmq's poll convention - confirm
            against the caller).
    """
    try:
        while True:
            events = await self.heartbeat_socket.poll(timeout=timeout)
            if events == 0:
                # Nothing arrived within the timeout: record the error
                # and stop listening.
                self._set_errored(
                    TimeoutError("No heartbeat received "
                                 "from MQLLMEngine"))
                logger.debug("Shutting down MQLLMEngineClient check "
                             "health loop due to timeout")
                break

            # A heartbeat arrived - validate the message itself.
            await self._check_success(
                error_message="Heartbeat failed.",
                socket=self.heartbeat_socket)
            logger.debug("Heartbeat successful.")

    except asyncio.CancelledError:
        # Cancellation is handled here and not re-raised.
        logger.debug("Shutting down MQLLMEngineClient check health loop.")

    except Exception as e:
        # Any other failure is recorded through _set_errored.
        self._set_errored(e)
async def run_output_handler_loop(self):
    """Get RequestOutputs from Engine and stream to Request Queues.

    Runs until cancelled. While the output socket stays silent, the loop
    checks ``self.errored`` after every poll timeout; once the client is
    errored, every open queue receives ``ENGINE_DEAD_ERROR`` and the loop
    returns. Each received message is unpickled and is either:

    * an ``RPCError`` (or, unexpectedly, a bare exception), which is
      delivered to the queue of the request it names, or to every queue
      when it names none; or
    * an iterable of request outputs, each delivered to the queue
      registered under its ``request_id``.

    Outputs for request ids without a registered queue are dropped.
    """
    try:
        while True:
            # Poll, checking for ENGINE_DEAD
            while await self.output_socket.poll(
                    timeout=VLLM_RPC_TIMEOUT) == 0:
                logger.debug("Waiting for output from MQLLMEngine.")

                # If errored, alert all running requests.
                if self.errored:
                    for queue_j in tuple(self.output_queues.values()):
                        queue_j.put_nowait(
                            ENGINE_DEAD_ERROR(self._errored_with))
                    return

            message: Frame = await self.output_socket.recv(copy=False)
            # NOTE(review): pickle.loads trusts the peer completely; this
            # is only safe while the socket is a local IPC channel to our
            # own engine process.
            request_outputs = pickle.loads(message.buffer)

            is_error = isinstance(request_outputs,
                                  (BaseException, RPCError))
            if is_error:
                if isinstance(request_outputs, RPCError):
                    rpc_error: RPCError = request_outputs
                    request_id = rpc_error.request_id
                    exception = rpc_error.exception
                    is_engine_errored = rpc_error.is_engine_errored
                else:
                    # MQLLMEngine should always return an RPCError to
                    # the output_socket when an issue arises.
                    # If we are here, we are in a bad state and
                    # should shut down the server.
                    error: BaseException = request_outputs
                    logger.error(
                        "Received Exception %s rather than RPCError from "
                        "MQLLMEngine. This should never happen.", error)
                    request_id = None
                    exception = error
                    is_engine_errored = True

                # Set to error state only on engine critical error
                # (and record only the first one). Compare against None
                # explicitly, as the rest of the class does, so that an
                # exception type with falsy truthiness is never replaced.
                if is_engine_errored and self._errored_with is None:
                    self._errored_with = exception

                if request_id is None:
                    # No specific request: every open stream gets it.
                    for queue_i in tuple(self.output_queues.values()):
                        queue_i.put_nowait(exception)
                else:
                    queue = self.output_queues.get(request_id)
                    if queue is not None:
                        queue.put_nowait(exception)
            else:
                # Put each output into the appropriate stream.
                for request_output in request_outputs:
                    queue = self.output_queues.get(
                        request_output.request_id)
                    if queue is not None:
                        queue.put_nowait(request_output)

    except asyncio.CancelledError:
        logger.debug("Shutting down MQLLMEngineClient output handler.")
async def setup(self):
    """Prepare the client before it starts sending requests to the server.

    Opens a temporary data socket, waits for the server's startup response,
    stores its ``tracing_enabled`` value in ``self.tracing_flag`` and then
    launches the heartbeat loop as a background task in ``self.health_loop``.
    """
    with self.get_data_socket() as data_socket:
        # Block until the server answers the readiness query.
        startup_response = await self._wait_for_server_rpc(data_socket)
        self.tracing_flag = startup_response.tracing_enabled

        # Only now that the server is up, begin watching its heartbeats.
        heartbeat = self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)
        self.health_loop = asyncio.create_task(heartbeat)
def close(self):
    """Tear down the ZeroMQ context and cancel the background tasks.

    The context is destroyed first with ``linger=0``, which closes every
    socket created from it. The health loop is cancelled only if it was
    started; the output loop is always cancelled.
    """
    self.context.destroy(linger=0)

    health_task = self.health_loop
    if health_task is not None:
        health_task.cancel()
    self.output_loop.cancel()
def _set_errored(self, e: BaseException):
    """Log ``e`` and keep it as the client's error if none is recorded yet.

    Only the first error is stored in ``self._errored_with``; later calls
    still log but leave the stored error untouched.

    Args:
        e: The exception to log and (possibly) record.
    """
    # This method is also called outside of ``except`` blocks (e.g. on a
    # heartbeat timeout), where ``logger.exception`` would attach an empty
    # "NoneType: None" traceback. Passing the exception explicitly logs at
    # the same ERROR level with the traceback that belongs to ``e``.
    logger.error(repr(e), exc_info=e)
    if self._errored_with is None:
        self._errored_with = e
@staticmethod
async def _send_get_data_rpc_request(request: RPCStartupRequest,
                                     expected_type: Any,
                                     error_message: str,
                                     socket: Socket) -> Any:
    """Send a pickled request over ``socket`` and return the typed reply.

    Args:
        request: The request object to pickle and send.
        expected_type: Type (or tuple of types) the reply must be an
            instance of.
        error_message: Message of the ``ValueError`` raised when the reply
            has the wrong type.
        socket: Socket used for both sending and receiving.

    Returns:
        The unpickled reply.

    Raises:
        TimeoutError: No reply became available within VLLM_RPC_TIMEOUT.
        ValueError: The reply is not an instance of ``expected_type``.
        BaseException: The reply itself is an exception; it is re-raised.
    """
    # Ping the RPCServer with the request.
    payload = pickle.dumps(request)
    await socket.send_multipart((payload, ), copy=False)

    # Make sure the server responds in time.
    reply_ready = await socket.poll(timeout=VLLM_RPC_TIMEOUT)
    if reply_ready == 0:
        raise TimeoutError("RPCServer didn't reply within "
                           f"{VLLM_RPC_TIMEOUT} ms")

    # Await the data from the server.
    frame = await socket.recv(copy=False)
    data = pickle.loads(frame.buffer)

    if isinstance(data, BaseException):
        raise data
    if not isinstance(data, expected_type):
        raise ValueError(error_message)
    return data
@staticmethod
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
                                    socket: Socket):
    """Pickle ``request`` and send it over ``socket`` without awaiting a reply.

    Raises:
        MQClientClosedError: ``socket`` is already closed.
    """
    if socket.closed:
        raise MQClientClosedError()

    payload = pickle.dumps(request)
    await socket.send_multipart((payload, ))
async def _await_ack(self, error_message: str, socket: Socket):
    """Wait for a success acknowledgement on ``socket``.

    Args:
        error_message: Forwarded to ``self._check_success`` for the case
            where the reply is not the success marker.
        socket: Socket the acknowledgement is expected on.

    Raises:
        MQClientClosedError: ``socket`` is already closed.
        TimeoutError: Nothing became readable within VLLM_RPC_TIMEOUT.
    """
    if socket.closed:
        raise MQClientClosedError()

    reply_ready = await socket.poll(timeout=VLLM_RPC_TIMEOUT)
    if reply_ready == 0:
        raise TimeoutError("MQLLMEngine didn't reply within "
                           f"{VLLM_RPC_TIMEOUT}ms")

    await self._check_success(error_message, socket)
@staticmethod
async def _check_success(error_message: str, socket: Socket):
    """Receive one message from ``socket`` and require the success marker.

    Args:
        error_message: Message of the ``ValueError`` raised when the reply
            is neither an exception nor ``VLLM_RPC_SUCCESS_STR``.
        socket: Socket to receive the reply from.

    Raises:
        MQClientClosedError: ``socket`` is already closed.
        ValueError: The reply is not the success string.
        BaseException: The reply itself is an exception; it is re-raised.
    """
    if socket.closed:
        raise MQClientClosedError()

    frame = await socket.recv(copy=False)
    reply = pickle.loads(frame.buffer)

    # An exception sent back by the peer is surfaced as-is.
    if isinstance(reply, BaseException):
        raise reply

    is_success = isinstance(reply, str) and reply == VLLM_RPC_SUCCESS_STR
    if not is_success:
        raise ValueError(error_message)
async def get_tokenizer(self, lora_request: LoRARequest):
    """Return what ``self.tokenizer`` resolves for ``lora_request``."""
    resolved = await self.tokenizer.get_lora_tokenizer_async(lora_request)
    return resolved
async def get_decoding_config(self) -> DecodingConfig:
    """Return the decoding config stored on this client."""
    config = self.decoding_config
    return config
async def get_model_config(self) -> ModelConfig:
    """Return the model config stored on this client."""
    config = self.model_config
    return config
async def is_tracing_enabled(self) -> bool:
    """Return the tracing flag stored in ``self.tracing_flag``."""
    flag = self.tracing_flag
    return flag
async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
    """Ask the RPCServer whether it is ready and return its startup response.

    Sends ``RPCStartupRequest.IS_SERVER_READY`` over ``socket`` and requires
    the reply to be an ``RPCStartupResponse``.
    """
    startup_response = await self._send_get_data_rpc_request(
        request=RPCStartupRequest.IS_SERVER_READY,
        expected_type=RPCStartupResponse,
        error_message="Unable to start RPC Server",
        socket=socket)
    return startup_response
async def abort(self, request_id: str):
    """Send an abort request for ``request_id`` to the RPC server.

    Does nothing if the input socket has already been closed
    (``MQClientClosedError`` is suppressed).
    """
    with suppress(MQClientClosedError):
        abort_request = RPCAbortRequest(request_id)
        await self._send_one_way_rpc_request(request=abort_request,
                                             socket=self.input_socket)
async def do_log_stats(self):
    """No-op: stats logging is handled on the MQLLMEngine side while polling."""
    return None
async def check_health(self):
    """Raise the recorded engine error, if there is one.

    This method does not probe the engine itself; it only reports the
    error stored in ``self._errored_with``, which other parts of the
    client (such as the heartbeat loop) set when the engine is unhealthy.
    """
    recorded_error = self._errored_with
    if recorded_error is None:
        return
    raise recorded_error
@property
def is_running(self) -> bool:
    """``True`` as long as no error has been recorded on this client."""
    healthy = not self.errored
    return healthy
@property
def is_stopped(self) -> bool:
    """Mirror of ``self.errored``: the client counts as stopped once errored."""
    stopped = self.errored
    return stopped
@property
def errored(self) -> bool:
    """``True`` once an error has been stored in ``self._errored_with``."""
    nothing_recorded = self._errored_with is None
    return not nothing_recorded
@property
def dead_error(self) -> BaseException:
    """Build an ``ENGINE_DEAD_ERROR`` wrapping the recorded error (may be None)."""
    recorded_error = self._errored_with
    return ENGINE_DEAD_ERROR(recorded_error)
def generate(
    self,
    inputs: PromptInputs,
    sampling_params: SamplingParams,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
    """Start a generation request and return its output stream.

    This is a thin wrapper: it forwards every argument to
    ``self._process_request`` and returns the async generator that call
    produces. Nothing is sent until the caller starts iterating it.

    Args:
        inputs: The inputs to the LLM. See
            :class:`~vllm.inputs.PromptInputs`
            for more details about the format of each input.
        sampling_params: The sampling parameters of the request.
        request_id: The unique id of the request.
        lora_request: LoRA request to use for generation, if any.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request to use
            for generation, if any.

    Returns:
        The async generator created by ``self._process_request``.
    """
    output_stream = self._process_request(inputs, sampling_params,
                                          request_id, lora_request,
                                          trace_headers,
                                          prompt_adapter_request)
    return output_stream
def encode(
    self,
    inputs: PromptInputs,
    pooling_params: PoolingParams,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
    """Start an embedding request and return its output stream.

    This is a thin wrapper: it forwards every argument to
    ``self._process_request`` and returns the async generator that call
    produces. Nothing is sent until the caller starts iterating it.

    Args:
        inputs: The inputs to the LLM. See
            :class:`~vllm.inputs.PromptInputs`
            for more details about the format of each input.
        pooling_params: The pooling parameters of the request.
        request_id: The unique id of the request.
        lora_request: LoRA request to use for generation, if any.
        trace_headers: OpenTelemetry trace headers.

    Yields:
        The output `EmbeddingRequestOutput` objects from the LLMEngine
        for the request.
    """
    output_stream = self._process_request(inputs, pooling_params,
                                          request_id, lora_request,
                                          trace_headers)
    return output_stream
async def _process_request(
    self,
    inputs: PromptInputs,
    params: Union[SamplingParams, PoolingParams],
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
        EmbeddingRequestOutput, None]]:
    """Send an RPCProcessRequest to the engine and stream its responses.

    Registers an output queue under ``request_id``, sends the pickled
    request over the input socket, then yields every output the output
    handler loop places on that queue until one reports ``finished``.

    Args:
        inputs: The inputs to the LLM.
        params: Sampling or pooling parameters for the request.
        request_id: The unique id of the request.
        lora_request: LoRA request to use, if any.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request to use, if any.

    Yields:
        The request outputs, in the order they arrive on the queue.

    Raises:
        BaseException: Any exception object delivered on the queue is
            re-raised. If the client is already errored, an
            ``ENGINE_DEAD_ERROR`` is raised before anything is sent.
    """
    # If already dead, error out.
    if self._errored_with is not None:
        raise ENGINE_DEAD_ERROR(self._errored_with)

    # 1) Create output queue for this request.
    queue: asyncio.Queue[Union[RequestOutput,
                               BaseException]] = asyncio.Queue()
    self.output_queues[request_id] = queue

    try:
        # 2) Detach logits processors so that they can be pickled
        # separately (may require cloudpickle which is slower)
        if isinstance(params, SamplingParams) and params.logits_processors:
            # Defensive shallow copy
            params = copy.copy(params)
            logits_processors = params.logits_processors
            params.logits_processors = None
            lp_bytes = cloudpickle.dumps(logits_processors)
        else:
            lp_bytes = None

        request_bytes = pickle.dumps(
            RPCProcessRequest(
                inputs=inputs,
                params=params,
                request_id=request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request))

        # 3) Send the RPCProcessRequest to the MQLLMEngine.
        parts = (request_bytes,
                 lp_bytes) if lp_bytes else (request_bytes, )
        await self.input_socket.send_multipart(parts, copy=False)

        # 4) Stream the RequestOutputs from the output queue. Note
        # that the output_loop pushes RequestOutput objects to this
        # queue after pulling them from the zmq socket.
        finished = False
        try:
            while not finished:
                request_output = await queue.get()

                if isinstance(request_output, BaseException):
                    raise request_output

                finished = request_output.finished
                yield request_output
        finally:
            # Request was canceled by the client.
            if not finished and not self.errored:
                await self.abort(request_id)
    finally:
        # Tolerate a missing entry: raising KeyError from this cleanup
        # (e.g. when another in-flight call reused the same request id
        # and already removed it) would replace the exception or
        # cancellation that is currently propagating.
        self.output_queues.pop(request_id, None)
async def start_profile(self) -> None:
    """Ask the engine to start profiling (one-way request, no reply awaited)."""
    profile_request = RPCUProfileRequest.START_PROFILE
    await self._send_one_way_rpc_request(request=profile_request,
                                         socket=self.input_socket)
async def stop_profile(self) -> None:
    """Ask the engine to stop profiling (one-way request, no reply awaited)."""
    profile_request = RPCUProfileRequest.STOP_PROFILE
    await self._send_one_way_rpc_request(request=profile_request,
                                         socket=self.input_socket)
import pickle
import signal
import threading
import time
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
import cloudpickle
import zmq
from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
# Union of the engine configuration classes imported from vllm.config.
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]
# Module-level logger, named after this module.
logger = init_logger(__name__)
# Timeout passed to input_socket.poll() while the engine loop is idle and
# waiting for new requests (milliseconds, per the name).
POLLING_TIMEOUT_MS = 10000
# Pre-pickled one-frame message that _send_healthy() writes to the heartbeat
# socket; pickled once here instead of on every heartbeat.
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
class MQLLMEngine:
    """A multiprocessing wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to enable use
    in a concurrent manner. It runs a background loop and uses zeromq to
    receive new requests and stream outputs incrementally via ipc.

    The :class:`LLMEngine` generate or encode process is kicked off when a new
    RPCProcessRequest is received by the input_socket.

    The engine loop (:meth:`run_engine_loop`) checks the input_socket for new
    requests, adds them to the LLMEngine if there are any, calls the internal
    :class:`LLMEngine.step()`, and sends the RequestOutputs back over
    the output_socket.

    If use_async_sockets is set, the logic associated with reading new
    requests from the socket and sending data to the socket is passed
    as a callback to the llm_engine, which calls the logic asynchronously
    such that the IPC can be overlapped with the GPU.

    Args:
        ipc_path: Base path for zeromq interprocess messaging
        use_async_sockets: Whether to make send/recv async with GPU
        log_requests: Whether to log the requests.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """

    def __init__(self,
                 ipc_path: str,
                 use_async_sockets: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        # For MQLLMEngine, we can use cached outputs, since each new request
        # output is immediately pickled and sent over the socket, which frees
        # the python object to be reused again.
        use_cached_outputs = True

        self.engine = LLMEngine(*args,
                                **kwargs,
                                use_cached_outputs=use_cached_outputs)
        self.log_requests = log_requests

        self.use_async_sockets = use_async_sockets
        if self.use_async_sockets:
            self.engine.process_request_outputs_callback = \
                self._async_socket_engine_callback

        self.ctx = zmq.Context()  # type: ignore[attr-defined]

        # Receive input from the client.
        self.input_socket = self.ctx.socket(zmq.constants.PULL)
        self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")

        # Send output stream back to client.
        self.output_socket = self.ctx.socket(zmq.constants.PUSH)
        self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

        # Send heartbeats back to client.
        self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
        self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

        # IPC path for the data socket.
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # Error state.
        self._errored_with: Optional[BaseException] = None

        # Heartbeat thread (created here, started in start()).
        self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
                                                 daemon=True)
        self._heartbeat_stop_event = threading.Event()
        # The heartbeat needs to be faster than what the client will wait for
        # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
        self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0

        self._last_alive_time = time.time()
        # The heartbeats can tolerate a long period of the engine chugging
        # away at a generation request.
        # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
        self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0

    @property
    def dead_error(self) -> BaseException:
        """An ``ENGINE_DEAD_ERROR``, wrapping the recorded error if any."""
        if self._errored_with is not None:
            return ENGINE_DEAD_ERROR(self._errored_with)
        else:
            return ENGINE_DEAD_ERROR()

    @classmethod
    def from_engine_args(cls, engine_args: AsyncEngineArgs,
                         usage_context: UsageContext, ipc_path: str):
        """Creates an MQLLMEngine from the engine arguments."""

        engine_config = engine_args.create_engine_config()

        executor_class = LLMEngine._get_executor_cls(engine_config)

        return cls(
            ipc_path=ipc_path,
            use_async_sockets=engine_config.model_config.use_async_output_proc,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context)

    def start(self):
        """Run the startup handshake, the heartbeat thread and the engine loop.

        Exceptions from any stage are logged rather than propagated, a
        ``KeyboardInterrupt`` is treated as a shutdown request, and
        :meth:`cleanup` always runs at the end.
        """
        try:
            try:
                logger.debug("Starting Startup Loop.")
                self.run_startup_loop()
                logger.debug("Starting heartbeat thread")
                self.heartbeat_thread.start()
                logger.debug("Starting Engine Loop.")
                self.run_engine_loop()
            except Exception as e:
                logger.exception(repr(e))
        except KeyboardInterrupt:
            logger.debug("Shutting down MQLLMEngine.")
        finally:
            logger.debug("MQLLMEngine is shut down.")
            self.cleanup()

    def cleanup(self):
        """Cleanup zeromq state on shutdown."""
        # Ask the heartbeat thread to stop before tearing the sockets down.
        self._heartbeat_stop_event.set()
        # Closes all sockets and destroys context.
        self.ctx.destroy(linger=0)
        del self.engine

    @contextmanager
    def make_data_socket(
            self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
        """Yield a ROUTER socket bound to ``self.data_ipc_path``.

        The socket is always closed with ``linger=0`` on exit.
        """
        socket = self.ctx.socket(zmq.constants.ROUTER)
        try:
            socket.bind(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    def run_startup_loop(self) -> None:
        """Startup loop for sending data from Engine -> Client.

        Receives a single startup request on a temporary data socket and
        answers it. Any failure while handling the request, including an
        unrecognised request, is sent back to the client as a pickled
        exception.
        """

        with self.make_data_socket() as socket:
            # Receive outside the ``try``: without the client's identity
            # frame there is nobody to send an error reply to, so a failed
            # receive must propagate instead of reaching the send below.
            identity, message = socket.recv_multipart(copy=False)

            response: Union[RPCStartupResponse, BaseException]
            try:
                request: RPCStartupRequest = pickle.loads(message.buffer)

                # Handle the query from the Client.
                if request == RPCStartupRequest.IS_SERVER_READY:
                    tracing_enabled = self.engine.is_tracing_enabled()
                    response = RPCStartupResponse(
                        tracing_enabled=tracing_enabled)
                else:
                    # Always produce a reply; leaving ``response`` unbound
                    # here would crash the send below.
                    response = ValueError(
                        f"Unknown RPCStartupRequest: {request!r}")

            except Exception as e:
                response = e

            socket.send_multipart((identity, pickle.dumps(response)),
                                  copy=False)

    def run_engine_loop(self):
        """Core busy loop of the LLMEngine."""

        while True:
            # Record liveness for the heartbeat thread.
            self._alive()
            if not self.engine.has_unfinished_requests():
                # Poll until there is work to do.
                while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                    self._alive()
                    self.engine.do_log_stats()
                    logger.debug("Waiting for new requests in engine loop.")

            # Handle any input from the client.
            self.handle_new_input()

            # Engine step.
            request_outputs = self.engine_step()

            # Send request outputs (if async, done in engine_step callback).
            if not self.use_async_sockets:
                self._send_outputs(request_outputs)

    def engine_step(self) -> List[RequestOutput]:
        """Engine step wrapper with error handling.

        On any failure other than ``SystemExit``, the error is recorded,
        reported to the client as an engine-level ``RPCError``, and then
        re-raised.
        """
        try:
            return self.engine.step()
        except SystemExit:
            raise
        except BaseException as e:
            self._set_errored(e)
            rpc_err = RPCError(request_id=None,
                               is_engine_errored=True,
                               exception=e)
            self._send_outputs(rpc_err)
            raise

    def handle_new_input(self):
        """Handle new input from the socket.

        Drains every message currently waiting on the input socket and
        dispatches it by request type. Any failure is recorded, reported
        on the heartbeat socket, and re-raised.
        """
        try:
            while self.input_socket.poll(timeout=0) != 0:
                frames = self.input_socket.recv_multipart(copy=False)
                request = pickle.loads(frames[0].buffer)

                if isinstance(request, RPCProcessRequest):
                    if len(frames) > 1:
                        # Use cloudpickle for logits processors
                        assert isinstance(request.params, SamplingParams)
                        lprocs = cloudpickle.loads(frames[1].buffer)
                        request.params.logits_processors = lprocs
                    self._handle_process_request(request)
                elif isinstance(request, RPCAbortRequest):
                    self._handle_abort_request(request)
                elif isinstance(request, RPCUProfileRequest):
                    if request == RPCUProfileRequest.START_PROFILE:
                        self.start_profile()
                    else:
                        self.stop_profile()
                else:
                    raise ValueError("Unknown RPCRequest Type: "
                                     f"{type(request)}")

        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)
            raise

    def _handle_process_request(self, request: RPCProcessRequest):
        """Handle RPCProcessRequest by adding it to the LLMEngine."""
        request_id = request.request_id

        if self._errored_with is not None:
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=ENGINE_DEAD_ERROR(self._errored_with))
            self._send_outputs(rpc_err)
            # The client has been told the engine is dead; do not also
            # queue the request on the broken engine.
            return

        try:
            self.engine.add_request(
                request_id=request_id,
                inputs=request.inputs,
                params=request.params,
                lora_request=request.lora_request,
                trace_headers=request.trace_headers,
                prompt_adapter_request=request.prompt_adapter_request)

            if self.log_requests:
                logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not set self._errored = True here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        """Abort the named request in the engine, logging if enabled."""
        self.engine.abort_request(request.request_id)
        if self.log_requests:
            logger.info("Aborted request %s.", request.request_id)

    def _heartbeat_loop(self):
        """Send a heartbeat every interval until the stop event is set."""
        while not self._heartbeat_stop_event.wait(
                timeout=self.heartbeat_interval_seconds):
            # Loops until the stop event is set
            self._heartbeat()

        logger.debug("Exiting MQLLMEngine heartbeat thread")

    def _heartbeat(self):
        """Send one healthy or unhealthy heartbeat to the client."""
        # Send unhealthy if engine has already errored
        if self._errored_with is not None:
            self._send_unhealthy(self._errored_with)

        # Check for life of the main loop
        elif time.time() - self._last_alive_time > self.last_alive_threshold:
            self._send_unhealthy(RuntimeError("Engine loop has died"))

        else:
            # Otherwise- check health of the engine
            # self.engine.check_health() raises on unhealthy
            try:
                self.engine.check_health()
                self._send_healthy()
            except Exception as e:
                self._set_errored(e)
                self._send_unhealthy(e)

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
        """Send List of RequestOutput to RPCClient.

        Falsy ``outputs`` (e.g. an empty list) are not sent at all.
        """
        if outputs:
            output_bytes = pickle.dumps(outputs)
            self.output_socket.send_multipart((output_bytes, ), copy=False)

    def _send_healthy(self):
        """Send HEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

    def _send_unhealthy(self, error: BaseException):
        """Send UNHEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            error_bytes = pickle.dumps(error)
            self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)

    def _async_socket_engine_callback(self,
                                      request_outputs: REQUEST_OUTPUTS_T):
        """Callback used by engine to make socket handling async with GPU."""
        self._send_outputs(request_outputs)
        self.handle_new_input()

    def _set_errored(self, e: BaseException):
        """Record ``e`` as the engine error if none is recorded yet.

        Later errors are ignored, so the first failure is the one reported.
        """
        if self._errored_with is None:
            self._errored_with = e

    def _alive(self):
        """Record the current time as the engine loop's last sign of life."""
        self._last_alive_time = time.time()

    def start_profile(self) -> None:
        """Start profiling on the model executor."""
        # Exact type check, not isinstance: only a plain GPUExecutor is
        # called directly; every other executor goes through _run_workers.
        if type(self.engine.model_executor) is GPUExecutor:
            self.engine.model_executor.start_profile()
        else:
            self.engine.model_executor._run_workers("start_profile")

    def stop_profile(self) -> None:
        """Stop profiling on the model executor."""
        # Exact type check, not isinstance: only a plain GPUExecutor is
        # called directly; every other executor goes through _run_workers.
        if type(self.engine.model_executor) is GPUExecutor:
            self.engine.model_executor.stop_profile()
        else:
            self.engine.model_executor._run_workers("stop_profile")
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
                  ipc_path: str):
    """Build an MQLLMEngine from ``engine_args`` and run it until it stops.

    Before the engine is created, a SIGTERM handler is installed that
    raises ``KeyboardInterrupt``, so termination follows the same shutdown
    path as an interrupt.
    """

    def _on_sigterm(*_) -> None:
        # Interrupt server on sigterm
        raise KeyboardInterrupt("MQLLMEngine terminated")

    signal.signal(signal.SIGTERM, _on_sigterm)

    mq_engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
                                             usage_context=usage_context,
                                             ipc_path=ipc_path)
    mq_engine.start()
...@@ -9,8 +9,8 @@ from vllm.engine.output_processor.single_step import ( ...@@ -9,8 +9,8 @@ from vllm.engine.output_processor.single_step import (
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
SequenceOutput, SequenceStatus) SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter from vllm.utils import Counter
...@@ -110,10 +110,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -110,10 +110,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# we can take the first sample. # we can take the first sample.
samples = [output.samples[0] for output in outputs] samples = [output.samples[0] for output in outputs]
# -1 means the output token is not valid (eg. due to spec decode # entries in sample tokens may be invalid (eg. due to spec decode
# rejecting tokens). # rejecting tokens).
valid_samples = [ valid_samples = [
sample for sample in samples if sample.output_token != -1 sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
] ]
assert valid_samples assert valid_samples
......
...@@ -14,8 +14,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer ...@@ -14,8 +14,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable @runtime_checkable
class AsyncEngineClient(Protocol): class EngineClient(Protocol):
"""Protocol class for Clients to AsyncLLMEngine""" """Protocol class for Clients to Engine"""
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
...@@ -30,8 +30,8 @@ class AsyncEngineClient(Protocol): ...@@ -30,8 +30,8 @@ class AsyncEngineClient(Protocol):
... ...
@property @property
def limit_concurrency(self) -> Optional[int]: def dead_error(self) -> BaseException:
"""Maximum number of concurrently running requests.""" ...
def generate( def generate(
self, self,
......
...@@ -121,7 +121,6 @@ async def run_server(args: Namespace, ...@@ -121,7 +121,6 @@ async def run_server(args: Namespace,
shutdown_task = await serve_http( shutdown_task = await serve_http(
app, app,
engine=engine,
host=args.host, host=args.host,
port=args.port, port=args.port,
log_level=args.log_level, log_level=args.log_level,
......
...@@ -159,6 +159,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -159,6 +159,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index) hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"): if model_type in ("chameleon", "internvl_chat"):
return "<image>" return "<image>"
if model_type == "mllama":
return "<|image|>"
if model_type == "qwen2_vl": if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>" return "<|vision_start|><|image_pad|><|vision_end|>"
...@@ -358,6 +360,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam) ...@@ -358,6 +360,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam) _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
def _parse_chat_message_content_parts( def _parse_chat_message_content_parts(
...@@ -368,7 +371,11 @@ def _parse_chat_message_content_parts( ...@@ -368,7 +371,11 @@ def _parse_chat_message_content_parts(
texts: List[str] = [] texts: List[str] = []
mm_parser = mm_tracker.create_parser() mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT
has_image = False
for part in parts: for part in parts:
part_type = part["type"] part_type = part["type"]
if part_type == "text": if part_type == "text":
...@@ -383,6 +390,7 @@ def _parse_chat_message_content_parts( ...@@ -383,6 +390,7 @@ def _parse_chat_message_content_parts(
"will be ignored.") "will be ignored.")
mm_parser.parse_image(image_url["url"]) mm_parser.parse_image(image_url["url"])
has_image = True
elif part_type == "audio_url": elif part_type == "audio_url":
audio_url = _AudioParser(part)["audio_url"] audio_url = _AudioParser(part)["audio_url"]
...@@ -394,12 +402,20 @@ def _parse_chat_message_content_parts( ...@@ -394,12 +402,20 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(f"Unknown part type: {part_type}") raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts) text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts() if keep_multimodal_content:
if mm_placeholder_counts: text_prompt = "\n".join(texts)
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, role_content = [{'type': 'text', 'text': text_prompt}]
text_prompt)
if has_image:
return [ConversationMessage(role=role, content=text_prompt)] role_content = [{'type': 'image'}] + role_content
return [ConversationMessage(role=role,
content=role_content)] # type: ignore
else:
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts, text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
# No need to validate using Pydantic again # No need to validate using Pydantic again
......
...@@ -4,19 +4,18 @@ from http import HTTPStatus ...@@ -4,19 +4,18 @@ from http import HTTPStatus
from typing import Any from typing import Any
import uvicorn import uvicorn
from fastapi import FastAPI, Response from fastapi import FastAPI, Request, Response
from vllm import envs from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import find_process_using_port from vllm.utils import find_process_using_port
logger = init_logger(__name__) logger = init_logger(__name__)
async def serve_http(app: FastAPI, engine: AsyncEngineClient, async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
**uvicorn_kwargs: Any):
logger.info("Available routes are:") logger.info("Available routes are:")
for route in app.routes: for route in app.routes:
methods = getattr(route, "methods", None) methods = getattr(route, "methods", None)
...@@ -27,18 +26,9 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, ...@@ -27,18 +26,9 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if engine.limit_concurrency is not None:
logger.info(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing", engine.limit_concurrency)
uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
config = uvicorn.Config(app, **uvicorn_kwargs) config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config) server = uvicorn.Server(config)
_add_shutdown_handlers(app, server, engine) _add_shutdown_handlers(app, server)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
...@@ -64,19 +54,19 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, ...@@ -64,19 +54,19 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
logger.debug( logger.debug(
"port %s is used by process %s launched with command:\n%s", "port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline())) port, process, " ".join(process.cmdline()))
logger.info("Gracefully stopping http server") logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown() return server.shutdown()
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server, def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
engine: AsyncEngineClient) -> None:
"""Adds handlers for fatal errors that should crash the server""" """Adds handlers for fatal errors that should crash the server"""
@app.exception_handler(RuntimeError) @app.exception_handler(RuntimeError)
async def runtime_error_handler(_, __): async def runtime_error_handler(request: Request, __):
"""On generic runtime error, check to see if the engine has died. """On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM.""" handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine = request.app.state.engine_client
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running): and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server " logger.fatal("AsyncLLMEngine has failed, terminating server "
...@@ -91,7 +81,7 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server, ...@@ -91,7 +81,7 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(AsyncEngineDeadError) @app.exception_handler(AsyncEngineDeadError)
async def engine_dead_handler(_, __): async def async_engine_dead_handler(_, __):
"""Kill the server if the async engine is already dead. It will """Kill the server if the async engine is already dead. It will
not handle any further requests.""" not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
...@@ -100,3 +90,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server, ...@@ -100,3 +90,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
server.should_exit = True server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(MQEngineDeadError)
async def mq_engine_dead_handler(_, __):
    """Answer an ``MQEngineDeadError`` with an HTTP 500 response.

    The MQ engine is already dead and will not handle any further requests,
    so the server is flagged to exit as well, unless
    ``VLLM_KEEP_ALIVE_ON_ENGINE_DEATH`` is set.
    """
    status = HTTPStatus.INTERNAL_SERVER_ERROR
    if envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
        # Keep-alive requested: report the failure but leave the server up.
        return Response(status_code=status)

    logger.fatal("MQLLMEngine is already dead, terminating server "
                 "process")
    server.should_exit = True
    return Response(status_code=status)
import itertools
from contextlib import contextmanager from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
Union, cast, overload)
from tqdm import tqdm from tqdm import tqdm
...@@ -29,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of ...@@ -29,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Stores the token ids of the candidate together with its cumulative log
    probability. ``text`` is optional and is only filled in when the
    sequence is about to be returned to the user.
    """
    # Token ids of the sequence; the prompt tokens are included.
    tokens: List[int]
    cum_logprob: float = 0.0
    text: Optional[str] = None
@dataclass
class BeamSearchOutput:
    """Result of a beam search.

    Wraps the list of the best beam search sequences. As built by
    ``LLM.beam_search`` the list is cut to the beam width, so it holds at
    most that many sequences.
    """
    sequences: List[BeamSearchSequence]
class BeamSearchInstance:
    """Per-prompt bookkeeping for beam search.

    ``beams`` holds the sequences that are still being extended and starts
    out as a single sequence made of the prompt tokens. ``completed``
    collects the sequences set aside as finished and starts out empty.
    """

    def __init__(self, prompt_tokens: List[int]):
        # The search starts from one beam that contains only the prompt.
        initial_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [initial_beam]
        self.completed: List[BeamSearchSequence] = []
class LLM: class LLM:
"""An LLM for generating texts from given prompts and sampling parameters. """An LLM for generating texts from given prompts and sampling parameters.
...@@ -88,7 +122,9 @@ class LLM: ...@@ -88,7 +122,9 @@ class LLM:
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back When a sequence has context length larger than this, we fall back
to eager mode. to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_custom_all_reduce: See ParallelConfig disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`) :ref:`engine_args`)
...@@ -131,15 +167,14 @@ class LLM: ...@@ -131,15 +167,14 @@ class LLM:
max_seq_len_to_capture: int = 8192, max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
''' '''
LLM constructor. LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None) Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False for decoder-only models and True it defaults to False.
for encoder/decoder models, since encoder/decoder models
do not currently support CUDAGraph.
''' '''
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
...@@ -173,6 +208,7 @@ class LLM: ...@@ -173,6 +208,7 @@ class LLM:
max_seq_len_to_capture=max_seq_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce, disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
mm_processor_kwargs=mm_processor_kwargs,
**kwargs, **kwargs,
) )
self.llm_engine = LLMEngine.from_engine_args( self.llm_engine = LLMEngine.from_engine_args(
...@@ -284,7 +320,8 @@ class LLM: ...@@ -284,7 +320,8 @@ class LLM:
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None GuidedDecodingRequest]] = None,
priority: Optional[List[int]] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -303,6 +340,8 @@ class LLM: ...@@ -303,6 +340,8 @@ class LLM:
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for prompt_adapter_request: Prompt Adapter request to use for
generation, if any. generation, if any.
priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled.
Returns: Returns:
A list of ``RequestOutput`` objects containing the A list of ``RequestOutput`` objects containing the
...@@ -343,20 +382,122 @@ class LLM: ...@@ -343,20 +382,122 @@ class LLM:
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request) guided_options=guided_options_request,
priority=priority)
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput) return LLMEngine.validate_outputs(outputs, RequestOutput)
def beam_search(
    self,
    prompts: List[Union[str, List[int]]],
    beam_width: int,
    max_tokens: int,
    ignore_eos: bool = False,
) -> List[BeamSearchOutput]:
    """
    Generate sequences using beam search.

    Every step runs ``self.generate`` for exactly one new token on all
    live beams of all prompts at once, then keeps the ``beam_width``
    highest-scoring extensions per prompt.

    Args:
        prompts: A list of prompts. Each prompt can be a string or a list
            of token IDs.
        beam_width: The number of beams to keep at each step.
        max_tokens: The max number of tokens to generate for each prompt.
        ignore_eos: When False (the default), a candidate whose new token
            is the tokenizer's EOS token is moved to the completed set
            instead of being extended further. When True, EOS is treated
            like any other token.

    Returns:
        One ``BeamSearchOutput`` per prompt, in the order of ``prompts``.
        Each holds at most ``beam_width`` sequences, sorted by descending
        cumulative log probability, with ``text`` decoded from the full
        token list (prompt included).

    TODO: how does beam search work together with length penalty, frequency
    penalty, and stopping criteria, etc.?
    """
    tokenizer = self.get_tokenizer()
    # generate 2 * beam_width candidates at each step
    # following the huggingface transformers implementation
    # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
    beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                        max_tokens=1,
                                        temperature=0.0)
    instances: List[BeamSearchInstance] = []

    for prompt in prompts:
        # Prompts given as token id lists are used as-is; strings are
        # tokenized first.
        prompt_tokens = prompt if isinstance(
            prompt, list) else tokenizer.encode(prompt)
        instances.append(BeamSearchInstance(prompt_tokens))

    for _ in range(max_tokens):
        # Flatten the live beams of every prompt into one batch. Chaining
        # is linear in the number of beams, unlike repeated list addition.
        all_beams: List[BeamSearchSequence] = list(
            itertools.chain.from_iterable(instance.beams
                                          for instance in instances))

        # Nothing left to extend for any prompt: stop early.
        if not all_beams:
            break

        # pos[k] is the offset of instance k's first beam in `all_beams`,
        # so consecutive pairs give each instance's [start, end) slice.
        pos = [0] + list(
            itertools.accumulate(
                len(instance.beams) for instance in instances))
        instance_start_and_end: List[Tuple[int, int]] = list(
            zip(pos[:-1], pos[1:]))

        prompts_batch = [
            TokensPrompt(prompt_token_ids=beam.tokens)
            for beam in all_beams
        ]

        # only runs for one step
        # we don't need to use tqdm here
        output = self.generate(prompts_batch,
                               sampling_params=beam_search_params,
                               use_tqdm=False)

        for (start, end), instance in zip(instance_start_and_end,
                                          instances):
            instance_new_beams = []
            for i in range(start, end):
                current_beam = all_beams[i]
                result = output[i]

                if result.outputs[0].logprobs is None:
                    # if `result.outputs[0].logprobs` is None, it means
                    # the sequence is completed because of the max-model-len
                    # or abortion. we don't need to add it to the new beams.
                    continue

                # Only one token was generated, so position 0 holds the
                # candidate tokens for this step.
                logprobs = result.outputs[0].logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    new_beam = BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        cum_logprob=current_beam.cum_logprob +
                        logprob_obj.logprob)

                    if token_id == tokenizer.eos_token_id and \
                            not ignore_eos:
                        instance.completed.append(new_beam)
                    else:
                        instance_new_beams.append(new_beam)

            # Keep only the best `beam_width` extensions for this prompt.
            sorted_beams = sorted(instance_new_beams,
                                  key=lambda x: x.cum_logprob,
                                  reverse=True)
            instance.beams = sorted_beams[:beam_width]

    outputs = []
    for instance in instances:
        # Beams still live when the search stops compete with the
        # completed ones for the final ranking.
        instance.completed.extend(instance.beams)
        sorted_completed = sorted(instance.completed,
                                  key=lambda x: x.cum_logprob,
                                  reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            beam.text = tokenizer.decode(beam.tokens)
        outputs.append(BeamSearchOutput(sequences=best_beams))

    return outputs
def chat( def chat(
self, self,
messages: List[ChatCompletionMessageParam], messages: Union[List[ChatCompletionMessageParam],
List[List[ChatCompletionMessageParam]]],
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None, List[SamplingParams]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
tools: Optional[List[Dict[str, Any]]] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
""" """
Generate responses for a chat conversation. Generate responses for a chat conversation.
...@@ -369,8 +510,9 @@ class LLM: ...@@ -369,8 +510,9 @@ class LLM:
to the OpenAI API. to the OpenAI API.
Args: Args:
messages: A single conversation represented as a list of messages. messages: A list of conversations or a single conversation.
Each message is a dictionary with 'role' and 'content' keys. - Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation. sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it is a single value, it is applied to every prompt. When it
...@@ -387,40 +529,56 @@ class LLM: ...@@ -387,40 +529,56 @@ class LLM:
A list of ``RequestOutput`` objects containing the generated A list of ``RequestOutput`` objects containing the generated
responses in the same order as the input messages. responses in the same order as the input messages.
""" """
list_of_messages: List[List[ChatCompletionMessageParam]]
tokenizer = self.get_tokenizer() # Handle multi and single conversations
model_config = self.llm_engine.get_model_config() if is_list_of(messages, list):
# messages is List[List[...]]
conversation, mm_data = parse_chat_messages(messages, model_config, list_of_messages = messages
tokenizer)
prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
tokenizer,
messages=messages,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
)
else: else:
prompt = apply_hf_chat_template( # messages is List[...]
tokenizer, list_of_messages = [messages]
conversation=conversation,
chat_template=chat_template, prompts: List[Union[TokensPrompt, TextPrompt]] = []
add_generation_prompt=add_generation_prompt,
) for msgs in list_of_messages:
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
conversation, mm_data = parse_chat_messages(
msgs, model_config, tokenizer)
prompt_data: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt_data = apply_mistral_chat_template(
tokenizer,
messages=msgs,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tools=tools,
)
else:
prompt_data = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tools=tools,
)
prompt: Union[TokensPrompt, TextPrompt]
if is_list_of(prompt_data, int):
prompt = TokensPrompt(prompt_token_ids=prompt_data)
else:
prompt = TextPrompt(prompt=prompt_data)
inputs: PromptInputs if mm_data is not None:
if is_list_of(prompt, int): prompt["multi_modal_data"] = mm_data
inputs = TokensPrompt(prompt_token_ids=prompt)
else:
inputs = TextPrompt(prompt=prompt)
if mm_data is not None: prompts.append(prompt)
inputs["multi_modal_data"] = mm_data
return self.generate( return self.generate(
inputs, prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
...@@ -628,6 +786,7 @@ class LLM: ...@@ -628,6 +786,7 @@ class LLM:
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None, guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None,
) -> None: ) -> None:
if isinstance(inputs, (str, dict)): if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
...@@ -657,6 +816,7 @@ class LLM: ...@@ -657,6 +816,7 @@ class LLM:
lora_request=lora_request[i] if isinstance( lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request, lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority[i] if priority else 0,
) )
def _add_request( def _add_request(
...@@ -665,6 +825,7 @@ class LLM: ...@@ -665,6 +825,7 @@ class LLM:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request( self.llm_engine.add_request(
...@@ -673,6 +834,7 @@ class LLM: ...@@ -673,6 +834,7 @@ class LLM:
params, params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority,
) )
def _add_guided_processor( def _add_guided_processor(
......
...@@ -4,16 +4,21 @@ import inspect ...@@ -4,16 +4,21 @@ import inspect
import multiprocessing import multiprocessing
import os import os
import re import re
import signal
import socket
import tempfile import tempfile
from argparse import Namespace from argparse import Namespace
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus from http import HTTPStatus
from typing import AsyncIterator, Optional, Set from typing import AsyncIterator, Set
import uvloop
from fastapi import APIRouter, FastAPI, Request from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State
from starlette.routing import Mount from starlette.routing import Mount
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -21,7 +26,9 @@ import vllm.envs as envs ...@@ -21,7 +26,9 @@ import vllm.envs as envs
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
...@@ -39,12 +46,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -39,12 +46,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeRequest, TokenizeRequest,
TokenizeResponse, TokenizeResponse,
UnloadLoraAdapterRequest) UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import ( from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization) OpenAIServingTokenization)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -54,12 +60,6 @@ from vllm.version import __version__ as VLLM_VERSION ...@@ -54,12 +60,6 @@ from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
...@@ -68,49 +68,42 @@ logger = init_logger('vllm.entrypoints.openai.api_server') ...@@ -68,49 +68,42 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set() _running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool,
quantization: Optional[str]) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
quantization=quantization,
seed=0,
dtype="auto").embedding_mode
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
try:
async def _force_log(): if app.state.log_stats:
while True: engine_client: EngineClient = app.state.engine_client
await asyncio.sleep(10)
await async_engine_client.do_log_stats() async def _force_log():
while True:
if not engine_args.disable_log_stats: await asyncio.sleep(10.)
task = asyncio.create_task(_force_log()) await engine_client.do_log_stats()
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove) task = asyncio.create_task(_force_log())
_running_tasks.add(task)
yield task.add_done_callback(_running_tasks.remove)
else:
task = None
try:
yield
finally:
if task is not None:
task.cancel()
finally:
# Ensure app state including engine ref is gc'd
del app.state
@asynccontextmanager @asynccontextmanager
async def build_async_engine_client( async def build_async_engine_client(
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: args: Namespace) -> AsyncIterator[EngineClient]:
# Context manager to handle async_engine_client lifecycle # Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit # Ensures everything is shutdown and cleaned up on error/exit
global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
# Backend itself still global for the silly lil' health handler
global async_engine_client
async with build_async_engine_client_from_engine_args( async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine: engine_args, args.disable_frontend_multiprocessing) as engine:
async_engine_client = engine # type: ignore[assignment]
yield engine yield engine
...@@ -118,26 +111,35 @@ async def build_async_engine_client( ...@@ -118,26 +111,35 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args( async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False, disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]: ) -> AsyncIterator[EngineClient]:
""" """
Create AsyncEngineClient, either: Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly - in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC - multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed. Returns the Client or None if the creation failed.
""" """
# If manually triggered or embedding model, use AsyncLLMEngine in process. # Fall back
# TODO: support embedding model via RPC. # TODO: fill out feature matrix.
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, if (MQLLMEngineClient.is_unsupported_config(engine_args)
engine_args.quantization)
or disable_frontend_multiprocessing): or disable_frontend_multiprocessing):
engine_client = AsyncLLMEngine.from_engine_args( engine_config = engine_args.create_engine_config()
engine_args, usage_context=UsageContext.OPENAI_API_SERVER) uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
try: "uses_ray", False)
yield engine_client
finally: build_engine = partial(AsyncLLMEngine.from_engine_args,
engine_client.shutdown_background_loop() engine_args=engine_args,
engine_config=engine_config,
usage_context=UsageContext.OPENAI_API_SERVER)
if uses_ray:
# Must run in main thread with ray for its signal handlers to work
engine_client = build_engine()
else:
engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_engine)
yield engine_client
return return
# Otherwise, use the multiprocessing AsyncLLMEngine. # Otherwise, use the multiprocessing AsyncLLMEngine.
...@@ -158,56 +160,58 @@ async def build_async_engine_client_from_engine_args( ...@@ -158,56 +160,58 @@ async def build_async_engine_client_from_engine_args(
"and vLLM will properly handle cleanup.") "and vLLM will properly handle cleanup.")
# Select random path for IPC. # Select random path for IPC.
rpc_path = get_open_zmq_ipc_path() ipc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for RPC Path.", logger.info("Multiprocessing frontend to use %s for IPC Path.",
rpc_path) ipc_path)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
# Start RPCServer in separate process (holds the AsyncLLMEngine). # Start RPCServer in separate process (holds the LLMEngine).
context = multiprocessing.get_context("spawn")
# the current process might have CUDA context, # the current process might have CUDA context,
# so we need to spawn a new process # so we need to spawn a new process
rpc_server_process = context.Process( context = multiprocessing.get_context("spawn")
target=run_rpc_server,
args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) engine_process = context.Process(target=run_mp_engine,
rpc_server_process.start() args=(engine_args,
logger.info("Started engine process with PID %d", UsageContext.OPENAI_API_SERVER,
rpc_server_process.pid) ipc_path))
engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid)
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
try: try:
while True: while True:
try: try:
await rpc_client.setup() await mp_engine_client.setup()
break break
except TimeoutError: except TimeoutError:
if not rpc_server_process.is_alive(): if not engine_process.is_alive():
logger.error( raise RuntimeError(
"RPCServer process died before responding " "Engine process failed to start") from None
"to readiness probe")
yield None yield mp_engine_client # type: ignore[misc]
return
yield rpc_client # type: ignore[misc]
finally: finally:
# Ensure rpc server process was terminated # Ensure rpc server process was terminated
rpc_server_process.terminate() engine_process.terminate()
# Close all open connections to the backend # Close all open connections to the backend
rpc_client.close() mp_engine_client.close()
# Wait for server process to join # Wait for engine process to join
rpc_server_process.join() engine_process.join(4)
if engine_process.exitcode is None:
# Kill if taking longer than 5 seconds to stop
engine_process.kill()
# Lazy import for prometheus multiprocessing. # Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported. # before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/ # See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess from prometheus_client import multiprocess
multiprocess.mark_process_dead(rpc_server_process.pid) multiprocess.mark_process_dead(engine_process.pid)
router = APIRouter() router = APIRouter()
...@@ -239,16 +243,36 @@ def mount_metrics(app: FastAPI): ...@@ -239,16 +243,36 @@ def mount_metrics(app: FastAPI):
app.routes.append(metrics_route) app.routes.append(metrics_route)
def chat(request: Request) -> OpenAIServingChat:
    """Return the chat serving handler kept on the app state of `request`."""
    app_state = request.app.state
    return app_state.openai_serving_chat
def completion(request: Request) -> OpenAIServingCompletion:
    """Return the completion serving handler kept on the app state of
    `request`."""
    app_state = request.app.state
    return app_state.openai_serving_completion
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization serving handler kept on the app state of
    `request`."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
def embedding(request: Request) -> OpenAIServingEmbedding:
    """Return the embedding serving handler kept on the app state of
    `request`."""
    app_state = request.app.state
    return app_state.openai_serving_embedding
def engine_client(request: Request) -> EngineClient:
    """Return the engine client kept on the app state of `request`."""
    app_state = request.app.state
    return app_state.engine_client
@router.get("/health") @router.get("/health")
async def health() -> Response: async def health(raw_request: Request) -> Response:
"""Health check.""" """Health check."""
await async_engine_client.check_health() await engine_client(raw_request).check_health()
return Response(status_code=200) return Response(status_code=200)
@router.post("/tokenize") @router.post("/tokenize")
async def tokenize(request: TokenizeRequest): async def tokenize(request: TokenizeRequest, raw_request: Request):
generator = await openai_serving_tokenization.create_tokenize(request) generator = await tokenization(raw_request).create_tokenize(request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
...@@ -259,8 +283,8 @@ async def tokenize(request: TokenizeRequest): ...@@ -259,8 +283,8 @@ async def tokenize(request: TokenizeRequest):
@router.post("/detokenize") @router.post("/detokenize")
async def detokenize(request: DetokenizeRequest): async def detokenize(request: DetokenizeRequest, raw_request: Request):
generator = await openai_serving_tokenization.create_detokenize(request) generator = await tokenization(raw_request).create_detokenize(request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
...@@ -271,8 +295,8 @@ async def detokenize(request: DetokenizeRequest): ...@@ -271,8 +295,8 @@ async def detokenize(request: DetokenizeRequest):
@router.get("/v1/models") @router.get("/v1/models")
async def show_available_models(): async def show_available_models(raw_request: Request):
models = await openai_serving_completion.show_available_models() models = await completion(raw_request).show_available_models()
return JSONResponse(content=models.model_dump()) return JSONResponse(content=models.model_dump())
...@@ -286,7 +310,7 @@ async def show_version(): ...@@ -286,7 +310,7 @@ async def show_version():
async def create_chat_completion(request: ChatCompletionRequest, async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request): raw_request: Request):
generator = await openai_serving_chat.create_chat_completion( generator = await chat(raw_request).create_chat_completion(
request, raw_request) request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
...@@ -301,7 +325,7 @@ async def create_chat_completion(request: ChatCompletionRequest, ...@@ -301,7 +325,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@router.post("/v1/completions") @router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request): async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion( generator = await completion(raw_request).create_completion(
request, raw_request) request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
...@@ -314,7 +338,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -314,7 +338,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@router.post("/v1/embeddings") @router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request): async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding( generator = await embedding(raw_request).create_embedding(
request, raw_request) request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
...@@ -331,16 +355,16 @@ if envs.VLLM_TORCH_PROFILER_DIR: ...@@ -331,16 +355,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
"used for local development!") "used for local development!")
@router.post("/start_profile") @router.post("/start_profile")
async def start_profile(): async def start_profile(raw_request: Request):
logger.info("Starting profiler...") logger.info("Starting profiler...")
await async_engine_client.start_profile() await engine_client(raw_request).start_profile()
logger.info("Profiler started.") logger.info("Profiler started.")
return Response(status_code=200) return Response(status_code=200)
@router.post("/stop_profile") @router.post("/stop_profile")
async def stop_profile(): async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...") logger.info("Stopping profiler...")
await async_engine_client.stop_profile() await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.") logger.info("Profiler stopped.")
return Response(status_code=200) return Response(status_code=200)
...@@ -351,13 +375,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: ...@@ -351,13 +375,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"This should ONLY be used for local development!") "This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter") @router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest): async def load_lora_adapter(request: LoadLoraAdapterRequest,
response = await openai_serving_chat.load_lora_adapter(request) raw_request: Request):
response = await chat(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), return JSONResponse(content=response.model_dump(),
status_code=response.code) status_code=response.code)
response = await openai_serving_completion.load_lora_adapter(request) response = await completion(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), return JSONResponse(content=response.model_dump(),
status_code=response.code) status_code=response.code)
...@@ -365,13 +390,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: ...@@ -365,13 +390,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return Response(status_code=200, content=response) return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter") @router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest): async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
response = await openai_serving_chat.unload_lora_adapter(request) raw_request: Request):
response = await chat(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), return JSONResponse(content=response.model_dump(),
status_code=response.code) status_code=response.code)
response = await openai_serving_completion.unload_lora_adapter(request) response = await completion(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), return JSONResponse(content=response.model_dump(),
status_code=response.code) status_code=response.code)
...@@ -380,7 +406,13 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: ...@@ -380,7 +406,13 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
def build_app(args: Namespace) -> FastAPI: def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan) if args.disable_fastapi_docs:
app = FastAPI(openapi_url=None,
docs_url=None,
redoc_url=None,
lifespan=lifespan)
else:
app = FastAPI(lifespan=lifespan)
app.include_router(router) app.include_router(router)
app.root_path = args.root_path app.root_path = args.root_path
...@@ -396,7 +428,8 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -396,7 +428,8 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc): async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc)) chat = app.state.openai_serving_chat
err = chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST) status_code=HTTPStatus.BAD_REQUEST)
...@@ -428,33 +461,34 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -428,33 +461,34 @@ def build_app(args: Namespace) -> FastAPI:
return app return app
async def init_app( def init_app_state(
async_engine_client: AsyncEngineClient, engine_client: EngineClient,
model_config: ModelConfig,
state: State,
args: Namespace, args: Namespace,
) -> FastAPI: ) -> None:
app = build_app(args)
if args.served_model_name is not None: if args.served_model_name is not None:
served_model_names = args.served_model_name served_model_names = args.served_model_name
else: else:
served_model_names = [args.model] served_model_names = [args.model]
model_config = await async_engine_client.get_model_config()
if args.disable_log_requests: if args.disable_log_requests:
request_logger = None request_logger = None
else: else:
request_logger = RequestLogger(max_log_len=args.max_log_len) request_logger = RequestLogger(max_log_len=args.max_log_len)
global openai_serving_chat base_model_paths = [
global openai_serving_completion BaseModelPath(name=name, model_path=args.model)
global openai_serving_embedding for name in served_model_names
global openai_serving_tokenization ]
openai_serving_chat = OpenAIServingChat( state.engine_client = engine_client
async_engine_client, state.log_stats = not args.disable_log_stats
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config, model_config,
served_model_names, base_model_paths,
args.response_role, args.response_role,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters, prompt_adapters=args.prompt_adapters,
...@@ -463,48 +497,54 @@ async def init_app( ...@@ -463,48 +497,54 @@ async def init_app(
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice, enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser) tool_parser=args.tool_call_parser)
openai_serving_completion = OpenAIServingCompletion( state.openai_serving_completion = OpenAIServingCompletion(
async_engine_client, engine_client,
model_config, model_config,
served_model_names, base_model_paths,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters, prompt_adapters=args.prompt_adapters,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) )
openai_serving_embedding = OpenAIServingEmbedding( state.openai_serving_embedding = OpenAIServingEmbedding(
async_engine_client, engine_client,
model_config, model_config,
served_model_names, base_model_paths,
request_logger=request_logger, request_logger=request_logger,
) )
openai_serving_tokenization = OpenAIServingTokenization( state.openai_serving_tokenization = OpenAIServingTokenization(
async_engine_client, engine_client,
model_config, model_config,
served_model_names, base_model_paths,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
request_logger=request_logger, request_logger=request_logger,
chat_template=args.chat_template, chat_template=args.chat_template,
) )
app.root_path = args.root_path
return app
async def run_server(args, **uvicorn_kwargs) -> None: async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args) logger.info("args: %s", args)
async with build_async_engine_client(args) as async_engine_client: temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# If None, creation of the client failed and we exit. temp_socket.bind(("", args.port))
if async_engine_client is None:
return def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
app = build_app(args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
app = await init_app(async_engine_client, args) temp_socket.close()
shutdown_task = await serve_http( shutdown_task = await serve_http(
app, app,
engine=async_engine_client,
host=args.host, host=args.host,
port=args.port, port=args.port,
log_level=args.uvicorn_log_level, log_level=args.uvicorn_log_level,
...@@ -528,4 +568,4 @@ if __name__ == "__main__": ...@@ -528,4 +568,4 @@ if __name__ == "__main__":
parser = make_arg_parser(parser) parser = make_arg_parser(parser)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(run_server(args)) uvloop.run(run_server(args))
...@@ -31,8 +31,23 @@ class LoRAParserAction(argparse.Action): ...@@ -31,8 +31,23 @@ class LoRAParserAction(argparse.Action):
lora_list: List[LoRAModulePath] = [] lora_list: List[LoRAModulePath] = []
for item in values: for item in values:
name, path = item.split('=') if item in [None, '']: # Skip if item is None or empty string
lora_list.append(LoRAModulePath(name, path)) continue
if '=' in item and ',' not in item: # Old format: name=path
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(
f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list) setattr(namespace, self.dest, lora_list)
...@@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ...@@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=None, default=None,
nargs='+', nargs='+',
action=LoRAParserAction, action=LoRAParserAction,
help="LoRA module configurations in the format name=path. " help="LoRA module configurations in either 'name=path' format"
"Multiple modules can be specified.") "or JSON format. "
"Example (old format): 'name=path' "
"Example (new format): "
"'{\"name\": \"name\", \"local_path\": \"path\", "
"\"base_model_name\": \"id\"}'")
parser.add_argument( parser.add_argument(
"--prompt-adapters", "--prompt-adapters",
type=nullable_str, type=nullable_str,
...@@ -190,6 +209,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ...@@ -190,6 +209,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'ID numbers being printed in log.' 'ID numbers being printed in log.'
'\n\nDefault: Unlimited') '\n\nDefault: Unlimited')
parser.add_argument(
"--disable-fastapi-docs",
action='store_true',
default=False,
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
)
return parser return parser
......
...@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel): ...@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens: Optional[int] = 0 completion_tokens: Optional[int] = 0
# Small record pairing a request id with that request's usage information.
# NOTE(review): presumably filled in once a response has been produced —
# confirm against the callers that construct it.
class RequestResponseMetadata(BaseModel):
    # Identifier of the request this metadata belongs to (required).
    request_id: str
    # Usage information for the request; defaults to None when not set.
    final_usage_info: Optional[UsageInfo] = None
class JsonSchemaResponseFormat(OpenAIBaseModel): class JsonSchemaResponseFormat(OpenAIBaseModel):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
......
import asyncio
import pickle
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
from uuid import uuid4
import cloudpickle
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf: disable
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
logger = init_logger(__name__)
# Path used for inprocess proxy.
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
class RPCClientClosedError(Exception):
    """Signals that the RPC client was used after it had been closed.

    Closing the client also closes its ZMQ context, which normally happens
    during server shutdown. Calls such as abort and do_log_stats may still
    arrive afterwards and attempt to open a socket; that would surface as a
    ZMQError with a very long stack trace.

    Raising this dedicated exception instead lets callers suppress that
    situation explicitly.
    """
class AsyncEngineRPCClient:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
The overall design mirrors the Asynchronous Client Server Pattern
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
On startup, the RPCClient:
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
via ipc, which uses unix sockets under the hood
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
- makes ROUTER socket (from_api_server) that binds to a random
inproc address, which uses memory under the hood
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
- runs a proxy in a background asyncio task between
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER, ipc)
Each request handled by the asyncio api_server calls generate():
- make a DEALER socket that connects to from_api_server via inproc
- send a RPCGenerateRequest to the inproc socket
- background proxy forwards the request from inproc -> ipc
- RPCServer responds to the request one token at a time over ipc
- background proxy forwards the response from ipc -> inproc
The connection looks like this:
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
Message routing is performed via identities that are managed by the
ROUTER socket. ROUTER sockets track every connection it has and
tells the caller about these. The way it tells the caller is to stick
the connection identity in front of each message received. When we
send the message via a ROUTER, we first send an identity frame.
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
for more details on connection identities.
This proxy design enables us to use a single unix socket, which
improves performance by avoiding syscalls (~5%) and avoids resource limits
such as ulimit, which defaults to 1024 on ubuntu.
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
which is required to avoid dropping messages under high load.
This is generally not advisable. However, since we are in control
of both sides of the connection + failure on either side is
catastrophic to the overall system health and memory profiling
suggests limited memory overhead relative to asyncio, we will
proceed for now.
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
for more details on high water marks.
"""
def __init__(self, rpc_path: str) -> None:
    """Create the ZMQ context, the two proxy sockets and the proxy tasks.

    Args:
        rpc_path: ZMQ address that the DEALER socket facing the RPC
            server is bound to.

    Raises:
        ValueError: if the context's SOCKET_LIMIT is below
            VLLM_RPC_SOCKET_LIMIT_CUTOFF.

    Note:
        The proxy tasks are started with ``asyncio.create_task``, so the
        constructor has to run while an event loop is running.
    """
    self.context = zmq.asyncio.Context()
    # Timeout (ms) used when polling for a reply to a utility RPC.
    self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
    # Set to True by generate() once a health check against the server fails.
    self._errored = False

    # Maximum number of sockets that can be opened (typically 65536).
    # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
    socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
    assert isinstance(socket_limit, int)
    if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
        raise ValueError(
            f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
            "the number of concurrent requests vLLM can process. Launch "
            "vLLM with --disable-frontend-multiprocessing and open a "
            "GitHub issue so we can investigate.")

    # We only have 1 ipc connection that uses unix sockets, so it is
    # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. we will
    # not run into ulimit issues).
    self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)

    # IPC connection to RPC Server (uses unix sockets).
    # Note that this side binds to rpc_path.
    self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
    self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
    self.to_rpc_server.bind(rpc_path)

    # In process proxy to RPC Server (uses memory-based messaging).
    self.from_api_server: Socket = self.context.socket(
        zmq.constants.ROUTER)
    self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
    self.from_api_server.bind(INPROC_PROXY_PATH)

    # Asyncio background tasks for the proxy, one per direction.
    self.proxy_in_task = asyncio.create_task(
        self.run_proxy(self.from_api_server, self.to_rpc_server))
    self.proxy_out_task = asyncio.create_task(
        self.run_proxy(self.to_rpc_server, self.from_api_server))

    # Since we open 1 inproc socket per request, we have a hard cap on
    # the number of requests that can run in vLLM with frontend
    # multiprocessing. This value is used by uvicorn to launch
    # with --limit-concurrency to return 503 when server is overloaded.
    # Each request needs up to 2 sockets:
    # 1 for generate(), 1 for abort(), do_log_stats(), check_health().
    # NOTE(review): the trailing "- 2" presumably leaves headroom for the
    # proxy sockets created above — confirm.
    self.limit_concurrency = socket_limit // 2 - 2
async def run_proxy(self, socket_from: Socket, socket_to: Socket):
    """Forward every multipart message from one socket to the other, forever.

    The loop has no exit condition of its own; it ends only when one of the
    socket operations raises or the surrounding task is cancelled.
    """
    while True:
        await socket_to.send_multipart(
            await socket_from.recv_multipart(copy=False), copy=False)
async def setup(self):
    """Prepare the client: wait for the server, then cache its configs.

    After this returns, ``model_config``, ``decoding_config``,
    ``tracing_flag`` and ``tokenizer`` are available on the instance.
    """
    # Nothing else can be requested until the server confirms readiness.
    await self._wait_for_server_rpc()

    # Cache the configs that are served back to callers later on.
    self.model_config = await self._get_model_config_rpc()
    self.decoding_config = await self._get_decoding_config_rpc()
    self.tracing_flag = await self._is_tracing_enabled_rpc()

    # Remaining configs are only needed to build the tokenizer group.
    # TODO: refactor OAI server to avoid needing this info.
    scheduler_config = await self._get_scheduler_config_rpc()
    parallel_config = await self._get_parallel_config_rpc()
    lora_config = await self._get_lora_config_rpc()
    self.tokenizer = init_tokenizer_from_configs(
        model_config=self.model_config,
        scheduler_config=scheduler_config,
        parallel_config=parallel_config,
        enable_lora=bool(lora_config),
    )
def close(self):
    """Close both proxy sockets, then destroy the ZeroMQ context."""
    # Sockets first (api-facing, then server-facing), context last.
    for proxy_socket in (self.from_api_server, self.to_rpc_server):
        proxy_socket.close()
    self.context.destroy()
@contextmanager
def to_proxy_socket(self) -> Iterator[Socket]:
    """Yield a fresh DEALER socket connected to the in-process proxy.

    The socket is always closed (with ``linger=0``) when the ``with`` block
    exits, including when configuring or connecting it fails.

    Raises:
        RPCClientClosedError: if the ZMQ context has already been closed.
    """
    # Raise a sensible error if the client was already closed.
    # This can happen if a server shutdown is triggered but some coroutines
    # are still running requests.
    # There should not be a race condition with this check because we don't
    # yield to the event loop between here and opening the socket.
    if self.context.closed:
        raise RPCClientClosedError("The ZMQ client has already shut down")

    # Note that we use DEALER to enable asynchronous communication
    # to enable streaming.
    socket = self.context.socket(zmq.constants.DEALER)
    try:
        # Configure inside the try so the socket is released even if
        # setting the HWM or connecting raises.
        socket.set_hwm(VLLM_RPC_ZMQ_HWM)
        socket.connect(INPROC_PROXY_PATH)
        yield socket
    finally:
        socket.close(linger=0)
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                     expected_type: Any,
                                     error_message: str) -> Any:
    """Send an RPC request that is expecting data back.

    Args:
        request: utility request to send to the server.
        expected_type: type the unpickled reply must be an instance of.
            ``None`` is additionally accepted when this is ``LoRAConfig``.
        error_message: message of the ``ValueError`` raised when the reply
            has an unexpected type.

    Returns:
        The unpickled reply.

    Raises:
        TimeoutError: if no reply arrives within ``self._data_timeout`` ms.
        Exception: any exception object sent back by the server is
            re-raised as is.
        ValueError: if the reply is not of ``expected_type``.
    """
    with self.to_proxy_socket() as socket:
        # Ping RPCServer with a request.
        await socket.send_multipart((cloudpickle.dumps(request), ),
                                    copy=False)

        # Make sure the server responds
        if await socket.poll(timeout=self._data_timeout) == 0:
            raise TimeoutError("Server didn't reply within "
                               f"{self._data_timeout} ms")

        # Await the data from the Server.
        frame = await socket.recv(copy=False)
        assert isinstance(frame, Frame)
        # NOTE: pickle must only ever be fed data from a trusted peer.
        data = pickle.loads(frame.buffer)

    if isinstance(data, Exception):
        # Re-raise exceptions returned by the server
        raise data

    if isinstance(data, expected_type):
        return data

    # LoRAConfig can be None.
    if expected_type == LoRAConfig and data is None:
        return data

    # Exceptions were already handled above, so anything left is simply
    # a reply of the wrong type.
    raise ValueError(error_message)
async def _send_one_way_rpc_request(self,
                                    request: RPC_REQUEST_TYPE,
                                    error_message: str,
                                    socket: Optional[Socket] = None):
    """Send a request whose only expected reply is a success marker.

    Args:
        request: request object to pickle and send.
        error_message: logged when the server replies with an exception,
            and used as the ``ValueError`` message for any other
            non-success reply.
        socket: existing socket to reuse; a temporary proxy socket is
            opened when omitted.

    Raises:
        TimeoutError: if no reply arrives within ``self._data_timeout`` ms.
        Exception: an exception object returned by the server is re-raised.
        ValueError: if the reply is neither the success marker nor an
            exception.
    """

    async def round_trip(sock: Socket, payload: RPC_REQUEST_TYPE):
        await sock.send_multipart((cloudpickle.dumps(payload), ))

        if await sock.poll(timeout=self._data_timeout) == 0:
            raise TimeoutError("Server didn't reply within "
                               f"{self._data_timeout} ms")

        reply = await sock.recv(copy=False)
        assert isinstance(reply, Frame)
        return pickle.loads(reply.buffer)

    if socket is not None:
        # Use existing socket connection.
        response = await round_trip(socket, request)
    else:
        # Make a new socket connection.
        with self.to_proxy_socket() as proxy_socket:
            response = await round_trip(proxy_socket, request)

    if isinstance(response, str) and response == VLLM_RPC_SUCCESS_STR:
        return
    if isinstance(response, Exception):
        logger.error(error_message)
        raise response
    raise ValueError(error_message)
async def get_tokenizer(self, lora_request: LoRARequest):
    """Return the tokenizer the tokenizer group provides for `lora_request`."""
    tokenizer_group = self.tokenizer
    return await tokenizer_group.get_lora_tokenizer_async(lora_request)
async def get_decoding_config(self) -> DecodingConfig:
    """Return the DecodingConfig stored on this client."""
    cached_config = self.decoding_config
    return cached_config
async def get_model_config(self) -> ModelConfig:
    """Return the ModelConfig stored on this client."""
    cached_config = self.model_config
    return cached_config
async def is_tracing_enabled(self) -> bool:
    """Return the tracing flag stored on this client."""
    cached_flag = self.tracing_flag
    return cached_flag
async def _wait_for_server_rpc(self):
    """Send IS_SERVER_READY and require a success reply from the RPCServer."""
    readiness_probe = RPCUtilityRequest.IS_SERVER_READY
    await self._send_one_way_rpc_request(
        request=readiness_probe,
        error_message="Unable to start RPC Server")
async def _get_model_config_rpc(self) -> ModelConfig:
    """Request the ModelConfig from the RPC Server and return it."""
    model_config = await self._send_get_data_rpc_request(
        RPCUtilityRequest.GET_MODEL_CONFIG,
        expected_type=ModelConfig,
        error_message="Could not get ModelConfig from RPC Server")
    return model_config
async def _get_decoding_config_rpc(self) -> DecodingConfig:
    """Request the DecodingConfig from the RPC Server and return it."""
    decoding_config = await self._send_get_data_rpc_request(
        RPCUtilityRequest.GET_DECODING_CONFIG,
        expected_type=DecodingConfig,
        error_message="Could not get DecodingConfig from RPC Server")
    return decoding_config
async def _get_parallel_config_rpc(self) -> ParallelConfig:
    """Request the ParallelConfig from the RPC Server and return it."""
    parallel_config = await self._send_get_data_rpc_request(
        RPCUtilityRequest.GET_PARALLEL_CONFIG,
        expected_type=ParallelConfig,
        error_message="Could not get ParallelConfig from RPC Server")
    return parallel_config
async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
    """Request the SchedulerConfig from the RPC Server and return it."""
    scheduler_config = await self._send_get_data_rpc_request(
        RPCUtilityRequest.GET_SCHEDULER_CONFIG,
        expected_type=SchedulerConfig,
        error_message="Could not get SchedulerConfig from RPC Server")
    return scheduler_config
async def _get_lora_config_rpc(self) -> LoRAConfig:
    """Request the LoRAConfig from the RPC Server and return it.

    The underlying request helper also accepts a ``None`` reply for
    LoRAConfig, so the result may be ``None``.
    """
    lora_config = await self._send_get_data_rpc_request(
        RPCUtilityRequest.GET_LORA_CONFIG,
        expected_type=LoRAConfig,
        error_message="Could not get LoRAConfig from RPC Server")
    return lora_config
async def _is_tracing_enabled_rpc(self) -> bool:
    """Request the is_tracing_enabled flag from the RPC Server and return it."""
    tracing_enabled = await self._send_get_data_rpc_request(
        RPCUtilityRequest.IS_TRACING_ENABLED,
        expected_type=bool,
        error_message="Could not get is_tracing_enabled from RPC Server")
    return tracing_enabled
async def abort(self, request_id: str):
    """Ask the RPC Server to abort `request_id`, best effort.

    A closed client and a reply timeout are both ignored silently; any
    other failure propagates to the caller.
    """
    try:
        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")
    except (RPCClientClosedError, TimeoutError):
        # Timeouts are swallowed on purpose.
        # When the server is busy and a very large volume of abort
        # requests arrives, it is likely that it cannot ack all of them
        # in time. This has been seen when aborting 20k requests at once
        # while another 2k were processing: many acks timed out, yet the
        # server successfully aborted every request.
        # So we assume the server has received or will receive the abort
        # and ignore the timeout, which avoids a massive wall of
        # `TimeoutError` stack traces.
        pass
async def do_log_stats(self):
    """Ask the RPC Server to log its stats; a closed client is ignored."""
    try:
        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")
    except RPCClientClosedError:
        # The client was shut down in the meantime; nothing to report.
        pass
@property
def is_running(self) -> bool:
    """True as long as the client has not been flagged as errored."""
    has_errored = self._errored
    return not has_errored
@property
def is_stopped(self) -> bool:
    """Mirror of the internal errored flag."""
    has_errored = self._errored
    return has_errored
@property
def errored(self) -> bool:
    """Whether the client has been flagged as errored."""
    has_errored = self._errored
    return has_errored
async def generate(
    self,
    inputs: PromptInputs,
    sampling_params: SamplingParams,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
    """Send an RPCGenerateRequest to the RPCServer and stream responses.

    All arguments are forwarded unchanged inside the RPCGenerateRequest.

    Yields:
        Each unpickled output received from the server, until one reports
        ``finished``.

    Raises:
        Exception: an exception object received from the server is
            re-raised. Before re-raising, the server's health is checked
            (unless the client is already flagged) and ``_errored`` is set
            if that check fails.
    """

    # Tracks whether the stream ended normally; consulted in `finally`.
    finished = False
    try:
        with self.to_proxy_socket() as socket:
            # Send RPCGenerateRequest to the RPCServer.
            await socket.send_multipart((cloudpickle.dumps(
                RPCGenerateRequest(
                    inputs=inputs,
                    sampling_params=sampling_params,
                    request_id=request_id,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    prompt_adapter_request=prompt_adapter_request)), ))

            # Stream back the results from the RPC Server.
            while not finished:
                message = await socket.recv(copy=False)
                assert isinstance(message, Frame)
                request_output = pickle.loads(message.buffer)

                if isinstance(request_output, Exception):
                    # On exception, check if the server is still healthy
                    # possibly setting the `errored` property.
                    if not self._errored:
                        try:
                            # Reuse this request's socket for the probe.
                            await self.check_health(socket=socket)
                        except Exception as e:
                            self._errored = True
                            logger.exception(repr(e))

                    # NB: do before raising here so that the flag is set
                    # by the time the caller receives this exception
                    raise request_output

                finished = request_output.finished
                yield request_output

    finally:
        # Runs on every exit path. If the stream did not finish and the
        # client is not flagged as errored (e.g. the request was canceled
        # by the client), tell the server to abort the request.
        if not finished and not self._errored:
            await self.abort(request_id)
async def check_health(self, socket: Optional[Socket] = None) -> None:
    """Send IS_SERVER_HEALTHY; raises unless the server confirms success.

    Args:
        socket: optional existing socket to send the request on.
    """
    health_probe = RPCUtilityRequest.IS_SERVER_HEALTHY
    await self._send_one_way_rpc_request(
        request=health_probe,
        error_message="Got Unhealthy response from RPC Server",
        socket=socket)
async def encode(self, *args,
                 **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
    """Always raise NotImplementedError; all arguments are ignored."""
    unsupported = "Embeddings not supported with multiprocessing backend"
    raise NotImplementedError(unsupported)
async def start_profile(self) -> None:
    """Send START_PROFILE; raises unless the server confirms success."""
    profile_request = RPCUtilityRequest.START_PROFILE
    await self._send_one_way_rpc_request(
        request=profile_request,
        error_message="RPCRequest START_PROFILE failed.")
async def stop_profile(self) -> None:
    """Send STOP_PROFILE; raises unless the server confirms success."""
    profile_request = RPCUtilityRequest.STOP_PROFILE
    await self._send_one_way_rpc_request(
        request=profile_request,
        error_message="RPCRequest STOP_PROFILE failed.")
import asyncio
import pickle
import signal
from typing import Any, Coroutine, Union
import cloudpickle
import uvloop
import zmq
import zmq.asyncio
from typing_extensions import Never
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]
class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
             usage_context: UsageContext, rpc_path: str):
    """Build the engine, then open a DEALER socket connected to `rpc_path`.

    Args:
        async_engine_args: arguments the AsyncLLMEngine is created from.
        usage_context: forwarded to the engine factory.
        rpc_path: ZMQ address the server socket connects to.
    """
    # The engine is created before any ZMQ resource is allocated.
    self.engine = AsyncLLMEngine.from_engine_args(
        async_engine_args, usage_context=usage_context)

    # ZMQ context owning the server socket.
    self.context = zmq.asyncio.Context()

    # Single DEALER socket used for every request and reply.
    dealer: Socket = self.context.socket(zmq.constants.DEALER)
    self.socket = dealer
    dealer.set_hwm(VLLM_RPC_ZMQ_HWM)
    dealer.connect(rpc_path)
def cleanup(self):
    """Release the socket, the ZMQ context and the engine, in that order."""
    self.socket.close()
    self.context.destroy()
    self.engine.shutdown_background_loop()
    # Drop the attribute so the engine can be garbage collected.
    delattr(self, "engine")
async def get_config(self, identity, request):
    """Reply to a GET_*_CONFIG request with the matching engine config.

    Args:
        identity: routing frame sent back in front of the reply.
        request: one of the RPCUtilityRequest.GET_*_CONFIG values.

    Any exception, including the ValueError raised for an unknown request,
    is pickled and sent to the client instead of the config.
    """
    try:
        config: CONFIG_TYPE
        if request == RPCUtilityRequest.GET_MODEL_CONFIG:
            config = await self.engine.get_model_config()
        elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
            config = await self.engine.get_decoding_config()
        elif request == RPCUtilityRequest.GET_LORA_CONFIG:
            config = await self.engine.get_lora_config()
        elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
            config = await self.engine.get_scheduler_config()
        elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
            config = await self.engine.get_parallel_config()
        else:
            # Exceptions do not apply %-formatting to their arguments, so
            # the message has to be built explicitly.
            raise ValueError(f"Unknown Config Request: {request}")

        await self.socket.send_multipart((identity, pickle.dumps(config)),
                                         copy=False)

    except Exception as e:
        await self.socket.send_multipart((identity, pickle.dumps(e)),
                                         copy=False)
async def is_tracing_enabled(self, identity):
    """Send the engine's is_tracing_enabled flag to the client.

    If the engine call raises, the exception is sent instead, so the
    client gets the real error rather than waiting for a reply that never
    comes.

    Args:
        identity: routing frame sent back in front of the reply.
    """
    try:
        result = await self.engine.is_tracing_enabled()
    except Exception as e:
        result = e

    await self.socket.send_multipart((identity, pickle.dumps(result)))
async def do_log_stats(self, identity):
    """Log engine stats and tell the client whether that succeeded.

    On success the success marker is sent; if the engine call raises, the
    exception is sent instead so the client is not left waiting.

    Args:
        identity: routing frame sent back in front of the reply.
    """
    try:
        await self.engine.do_log_stats()
        result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
    except Exception as e:
        result = e

    await self.socket.send_multipart((identity, pickle.dumps(result)))
async def is_server_ready(self, identity):
    """Reply with the success marker to signal readiness to the client."""
    ready_frame = pickle.dumps(VLLM_RPC_SUCCESS_STR)
    await self.socket.send_multipart((identity, ready_frame))
async def abort(self, identity, request: RPCAbortRequest):
    """Abort `request.request_id` in the engine and report the outcome.

    The reply is the success marker, or the exception raised by the
    engine's abort call.
    """
    result: Union[str, Exception]
    try:
        await self.engine.abort(request.request_id)
    except Exception as abort_error:
        result = abort_error
    else:
        result = VLLM_RPC_SUCCESS_STR

    reply_frame = pickle.dumps(result)
    await self.socket.send_multipart((identity, reply_frame))
async def generate(self, identity, generate_request: RPCGenerateRequest):
    """Run a generation in the engine and stream each output to the client.

    Every output is pickled and sent behind `identity`. If anything raises
    (creating the generator, iterating it, or sending), the exception is
    pickled and sent instead.
    """
    try:
        output_stream = self.engine.generate(
            generate_request.inputs,
            sampling_params=generate_request.sampling_params,
            request_id=generate_request.request_id,
            lora_request=generate_request.lora_request,
            trace_headers=generate_request.trace_headers,
            prompt_adapter_request=generate_request.prompt_adapter_request)

        async for output in output_stream:
            payload = pickle.dumps(output)
            await self.socket.send_multipart((identity, payload),
                                             copy=False)

    except Exception as stream_error:
        error_payload = pickle.dumps(stream_error)
        await self.socket.send_multipart((identity, error_payload),
                                         copy=False)
async def check_health(self, identity):
    """Run the engine health check and report the outcome to the client.

    The reply is the success marker, or the exception raised while
    checking health or sending that marker.
    """
    try:
        await self.engine.check_health()
        healthy_frame = pickle.dumps(VLLM_RPC_SUCCESS_STR)
        await self.socket.send_multipart((identity, healthy_frame))

    except Exception as health_error:
        error_frame = pickle.dumps(health_error)
        await self.socket.send_multipart((identity, error_frame),
                                         copy=False)
async def start_profile(self, identity):
    """Start the engine profiler and report the outcome to the client.

    Replies to ``identity`` with the pickled ``VLLM_RPC_SUCCESS_STR`` once
    ``self.engine.start_profile()`` completes.  If the engine call raises,
    the pickled exception is sent instead, matching the sibling handlers,
    so the requester is never left without a reply.

    Args:
        identity: First frame of the received multipart request; echoed
            back as the first frame of the reply.
    """
    logger.info("Starting profiler...")
    try:
        await self.engine.start_profile()
    except Exception as e:
        # Report the failure to the caller instead of dropping the request.
        reply = e
    else:
        logger.info("Profiler started.")
        reply = VLLM_RPC_SUCCESS_STR
    await self.socket.send_multipart((
        identity,
        pickle.dumps(reply),
    ))
async def stop_profile(self, identity):
    """Stop the engine profiler and report the outcome to the client.

    Replies to ``identity`` with the pickled ``VLLM_RPC_SUCCESS_STR`` once
    ``self.engine.stop_profile()`` completes.  If the engine call raises,
    the pickled exception is sent instead, matching the sibling handlers,
    so the requester is never left without a reply.

    Args:
        identity: First frame of the received multipart request; echoed
            back as the first frame of the reply.
    """
    logger.info("Stopping profiler...")
    try:
        await self.engine.stop_profile()
    except Exception as e:
        # Report the failure to the caller instead of dropping the request.
        reply = e
    else:
        logger.info("Profiler stopped.")
        reply = VLLM_RPC_SUCCESS_STR
    await self.socket.send_multipart((
        identity,
        pickle.dumps(reply),
    ))
def _make_handler_coro(self, identity,
                       message: Frame) -> Coroutine[Any, Any, Never]:
    """Route the zmq message to the handler coroutine.

    The message buffer is deserialized and the matching handler method is
    called with ``identity``; the un-awaited coroutine is returned for the
    caller to schedule.

    Raises:
        ValueError: If the deserialized request is not a generate, abort or
            utility request, or is a utility request with no handler here.
    """
    # NOTE(review): cloudpickle.loads can execute arbitrary code from the
    # payload; presumably only trusted peers can reach this socket - confirm.
    request = cloudpickle.loads(message.buffer)

    if isinstance(request, RPCGenerateRequest):
        return self.generate(identity, request)

    if isinstance(request, RPCAbortRequest):
        return self.abort(identity, request)

    if not isinstance(request, RPCUtilityRequest):
        raise ValueError(f"Unknown RPCRequest type: {request}")

    # Every config lookup shares one handler, which picks the config itself.
    config_requests = (
        RPCUtilityRequest.GET_MODEL_CONFIG,
        RPCUtilityRequest.GET_PARALLEL_CONFIG,
        RPCUtilityRequest.GET_DECODING_CONFIG,
        RPCUtilityRequest.GET_SCHEDULER_CONFIG,
        RPCUtilityRequest.GET_LORA_CONFIG,
    )
    if request in config_requests:
        return self.get_config(identity, request)

    if request == RPCUtilityRequest.DO_LOG_STATS:
        return self.do_log_stats(identity)
    if request == RPCUtilityRequest.IS_SERVER_READY:
        return self.is_server_ready(identity)
    if request == RPCUtilityRequest.IS_SERVER_HEALTHY:
        return self.check_health(identity)
    if request == RPCUtilityRequest.IS_TRACING_ENABLED:
        return self.is_tracing_enabled(identity)
    if request == RPCUtilityRequest.START_PROFILE:
        return self.start_profile(identity)
    if request == RPCUtilityRequest.STOP_PROFILE:
        return self.stop_profile(identity)

    raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
async def run_server_loop(self):
    """Inner RPC Server Loop

    Receives ``(identity, message)`` pairs from ``self.socket`` forever and
    starts the handler coroutine for each one as its own asyncio task, so
    the loop goes straight back to receiving.
    """
    in_flight = set()

    while True:
        # Wait for a request.
        identity, message = await self.socket.recv_multipart(copy=False)

        # Process the request async.
        handler = self._make_handler_coro(identity, message)
        handler_task = asyncio.create_task(handler)

        # Hold a strong reference until the task finishes: running tasks
        # can otherwise be GC'ed mid-execution. This is the documented
        # "fire-and-forget" pattern, see
        # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
        in_flight.add(handler_task)
        handler_task.add_done_callback(in_flight.discard)
async def run_server(server: AsyncEngineRPCServer):
    """Run the server loop until it ends or is interrupted, then clean up.

    ``server.run_server_loop()`` runs as a task on the current event loop.
    SIGINT and SIGTERM cancel that task; ``server.cleanup()`` is called on
    every exit path.
    """
    # Put the server task into the asyncio loop.
    event_loop = asyncio.get_running_loop()
    loop_task = event_loop.create_task(server.run_server_loop())

    # Interruption handling.
    def signal_handler() -> None:
        # Kill the server on interrupt / terminate
        loop_task.cancel()

    for signum in (signal.SIGINT, signal.SIGTERM):
        event_loop.add_signal_handler(signum, signal_handler)

    try:
        await loop_task
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Clean up all resources.
        server.cleanup()
def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, rpc_path: str):
    """Build an ``AsyncEngineRPCServer`` and run it under uvloop.

    Returns when ``run_server`` completes for the constructed server.
    """
    rpc_server = AsyncEngineRPCServer(async_engine_args, usage_context,
                                      rpc_path)
    serve_forever = run_server(rpc_server)
    uvloop.run(serve_forever)
...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput, ...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
...@@ -196,6 +197,10 @@ async def main(args): ...@@ -196,6 +197,10 @@ async def main(args):
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
model_config = await engine.get_model_config() model_config = await engine.get_model_config()
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
if args.disable_log_requests: if args.disable_log_requests:
request_logger = None request_logger = None
...@@ -206,7 +211,7 @@ async def main(args): ...@@ -206,7 +211,7 @@ async def main(args):
openai_serving_chat = OpenAIServingChat( openai_serving_chat = OpenAIServingChat(
engine, engine,
model_config, model_config,
served_model_names, base_model_paths,
args.response_role, args.response_role,
lora_modules=None, lora_modules=None,
prompt_adapters=None, prompt_adapters=None,
...@@ -216,7 +221,7 @@ async def main(args): ...@@ -216,7 +221,7 @@ async def main(args):
openai_serving_embedding = OpenAIServingEmbedding( openai_serving_embedding = OpenAIServingEmbedding(
engine, engine,
model_config, model_config,
served_model_names, base_model_paths,
request_logger=request_logger, request_logger=request_logger,
) )
......
...@@ -9,7 +9,7 @@ from typing import Union ...@@ -9,7 +9,7 @@ from typing import Union
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage, from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_hf_chat_template, apply_hf_chat_template,
apply_mistral_chat_template, apply_mistral_chat_template,
...@@ -22,8 +22,10 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -22,8 +22,10 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing, OpenAIServing,
PromptAdapterPath, PromptAdapterPath,
TextTokensPrompt) TextTokensPrompt)
...@@ -45,9 +47,9 @@ logger = init_logger(__name__) ...@@ -45,9 +47,9 @@ logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing): class OpenAIServingChat(OpenAIServing):
def __init__(self, def __init__(self,
async_engine_client: AsyncEngineClient, engine_client: EngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], base_model_paths: List[BaseModelPath],
response_role: str, response_role: str,
*, *,
lora_modules: Optional[List[LoRAModulePath]], lora_modules: Optional[List[LoRAModulePath]],
...@@ -57,9 +59,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -57,9 +59,9 @@ class OpenAIServingChat(OpenAIServing):
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False, enable_auto_tools: bool = False,
tool_parser: Optional[str] = None): tool_parser: Optional[str] = None):
super().__init__(async_engine_client=async_engine_client, super().__init__(engine_client=engine_client,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, base_model_paths=base_model_paths,
lora_modules=lora_modules, lora_modules=lora_modules,
prompt_adapters=prompt_adapters, prompt_adapters=prompt_adapters,
request_logger=request_logger, request_logger=request_logger,
...@@ -105,6 +107,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -105,6 +107,12 @@ class OpenAIServingChat(OpenAIServing):
logger.error("Error with model %s", error_check_ret) logger.error("Error with model %s", error_check_ret)
return error_check_ret return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
try: try:
( (
lora_request, lora_request,
...@@ -112,8 +120,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -112,8 +120,7 @@ class OpenAIServingChat(OpenAIServing):
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
model_config = self.model_config model_config = self.model_config
tokenizer = await self.async_engine_client.get_tokenizer( tokenizer = await self.engine_client.get_tokenizer(lora_request)
lora_request)
conversation, mm_data_future = parse_chat_messages_futures( conversation, mm_data_future = parse_chat_messages_futures(
request.messages, model_config, tokenizer) request.messages, model_config, tokenizer)
...@@ -123,7 +130,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -123,7 +130,8 @@ class OpenAIServingChat(OpenAIServing):
] ]
prompt: Union[str, List[int]] prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer): is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
prompt = apply_mistral_chat_template( prompt = apply_mistral_chat_template(
tokenizer, tokenizer,
messages=request.messages, messages=request.messages,
...@@ -159,15 +167,20 @@ class OpenAIServingChat(OpenAIServing): ...@@ -159,15 +167,20 @@ class OpenAIServingChat(OpenAIServing):
return self.create_error_response( return self.create_error_response(
"tool_choice = \"required\" is not supported!") "tool_choice = \"required\" is not supported!")
# "auto" tools requires --enable-auto-tool-choice if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
# and --tool-call-parser
if request.tool_choice == "auto" and not (
self.enable_auto_tools and self.tool_parser is not None): self.enable_auto_tools and self.tool_parser is not None):
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response( return self.create_error_response(
"\"auto\" tool choice requires " "\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set") "--enable-auto-tool-choice and --tool-call-parser to be set")
request_id = f"chat-{random_uuid()}" request_id = f"chat-{random_uuid()}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try: try:
guided_decode_logits_processor = ( guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer)) await self._guided_decode_logits_processor(request, tokenizer))
...@@ -206,8 +219,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -206,8 +219,8 @@ class OpenAIServingChat(OpenAIServing):
if mm_data is not None: if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = ( is_tracing_enabled = (await
await self.async_engine_client.is_tracing_enabled()) self.engine_client.is_tracing_enabled())
trace_headers = None trace_headers = None
if is_tracing_enabled and raw_request: if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers) trace_headers = extract_trace_headers(raw_request.headers)
...@@ -215,7 +228,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -215,7 +228,7 @@ class OpenAIServingChat(OpenAIServing):
and contains_trace_headers(raw_request.headers)): and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning() log_tracing_disabled_warning()
result_generator = self.async_engine_client.generate( result_generator = self.engine_client.generate(
engine_inputs, engine_inputs,
sampling_params, sampling_params,
request_id, request_id,
...@@ -234,11 +247,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -234,11 +247,13 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response # Streaming response
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer) request, result_generator, request_id, conversation, tokenizer,
request_metadata)
try: try:
return await self.chat_completion_full_generator( return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer) request, result_generator, request_id, conversation, tokenizer,
request_metadata)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -255,8 +270,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -255,8 +270,9 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0] model_name = self.base_model_paths[0].name
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
...@@ -293,6 +309,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -293,6 +309,8 @@ class OpenAIServingChat(OpenAIServing):
async for res in result_generator: async for res in result_generator:
if res.prompt_token_ids is not None: if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids) num_prompt_tokens = len(res.prompt_token_ids)
if res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(res.encoder_prompt_token_ids)
# We need to do it here, because if there are exceptions in # We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST # the result_generator, it needs to be sent as the FIRST
...@@ -573,6 +591,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -573,6 +591,13 @@ class OpenAIServingChat(OpenAIServing):
exclude_unset=True, exclude_none=True)) exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n" yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
num_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_completion_tokens,
total_tokens=num_prompt_tokens + num_completion_tokens)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
logger.error("error in chat completion stream generator: %s", e) logger.error("error in chat completion stream generator: %s", e)
...@@ -588,9 +613,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -588,9 +613,10 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]: ) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0] model_name = self.base_model_paths[0].name
created_time = int(time.time()) created_time = int(time.time())
final_res: Optional[RequestOutput] = None final_res: Optional[RequestOutput] = None
...@@ -707,6 +733,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -707,6 +733,9 @@ class OpenAIServingChat(OpenAIServing):
completion_tokens=num_generated_tokens, completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens,
) )
request_metadata.final_usage_info = usage
response = ChatCompletionResponse( response = ChatCompletionResponse(
id=request_id, id=request_id,
created=created_time, created=created_time,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment