Commit af7f4372 authored by zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
import asyncio
import time
from dataclasses import dataclass
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
import torch
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
......@@ -12,19 +14,25 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.metrics import StatLoggerBase
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
PromptComponents)
from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
......@@ -58,41 +66,61 @@ def _log_task_completion(task: asyncio.Task,
error_callback(exception)
raise AsyncEngineDeadError(
"Task finished unexpectedly. This should never happen! "
"Please open an issue on Github. See stack trace above for the"
"Please open an issue on Github. See stack trace above for the "
"actual cause.") from e
STOP_ITERATION = Exception() # Sentinel
class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
that can be iterated over asynchronously."""
that can be iterated over asynchronously via an async generator."""
def __init__(self, request_id: str) -> None:
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id
self._cancel = cancel
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None:
if self._finished:
return
if not self._finished:
self._queue.put_nowait(item)
def finish(self) -> None:
self._queue.put_nowait(StopAsyncIteration())
def finish(
self,
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
) -> None:
if not self._finished:
self._finished = True
self._queue.put_nowait(
exception if self._is_raisable(exception) else STOP_ITERATION)
@property
def finished(self) -> bool:
return self._finished
def __aiter__(self):
return self
async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
try:
while True:
result = await self._queue.get()
if isinstance(result, Exception):
if self._is_raisable(result):
if result == STOP_ITERATION:
return
raise result
return result
yield result
except GeneratorExit:
self._cancel(self.request_id)
raise asyncio.CancelledError from None
@staticmethod
def _is_raisable(value: Any):
return isinstance(value, BaseException) or \
(isinstance(value, type) and \
issubclass(value, BaseException))
class RequestTracker:
......@@ -100,7 +128,7 @@ class RequestTracker:
def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = asyncio.Event()
......@@ -117,12 +145,12 @@ class RequestTracker:
"""Propagate an exception to request streams
(all if request_id is None)."""
if request_id is not None:
self._request_streams[request_id].put(exc)
self.abort_request(request_id)
self.abort_request(request_id, exception=exc)
else:
for rid, stream in self._request_streams.items():
stream.put(exc)
self.abort_request(rid)
# NB: tuple() used here because self.abort_request pops the stream
# out of self._request_streams, so we can't iterate on it directly
for rid in tuple(self._request_streams.keys()):
self.abort_request(rid, exception=exc)
def process_request_output(self,
request_output: Union[RequestOutput,
......@@ -131,26 +159,31 @@ class RequestTracker:
verbose: bool = False) -> None:
"""Process a request output from the engine."""
request_id = request_output.request_id
finished = request_output.finished
if finished:
stream = self._request_streams.pop(request_id, None)
else:
stream = self._request_streams.get(request_id)
# Guard against a KeyError which can occur if the request was aborted
# while the output was generated
if (stream := self._request_streams.get(request_id)) is not None:
if stream is not None:
stream.put(request_output)
if request_output.finished:
if verbose:
if finished:
stream.finish()
if verbose and finished:
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
def process_exception(self,
request_id: str,
exception: Exception,
exception: BaseException,
*,
verbose: bool = False) -> None:
"""Propagate an exception from the engine."""
self._request_streams[request_id].put(exception)
if verbose:
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
self.abort_request(request_id, exception=exception)
def add_request(self,
request_id: str,
......@@ -162,7 +195,8 @@ class RequestTracker:
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")
stream = AsyncStream(request_id)
abort_request = partial(self.abort_request, verbose=verbose)
stream = AsyncStream(request_id, abort_request)
self._new_requests.put_nowait((stream, {
"request_id": request_id,
**engine_add_request_kwargs
......@@ -175,38 +209,41 @@ class RequestTracker:
return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
def abort_request(self,
request_id: str,
*,
exception: Optional[Union[BaseException,
Type[BaseException]]] = None,
verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info("Aborted request %s.", request_id)
self._finished_requests.put_nowait(request_id)
self._aborted_requests.put_nowait(request_id)
if request_id not in self._request_streams or self._request_streams[
request_id].finished:
# The request has already finished or been aborted.
return
self._request_streams[request_id].finish()
stream = self._request_streams.pop(request_id, None)
if stream is not None:
stream.finish(exception=exception)
def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[Dict] = []
finished_requests: Set[str] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
while not self._aborted_requests.empty():
request_id = self._aborted_requests.get_nowait()
finished_requests.add(request_id)
self._request_streams.pop(request_id, None)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if stream.request_id in finished_requests:
request_id = stream.request_id
if request_id in finished_requests:
# The request has already been aborted.
stream.finish()
continue
self._request_streams[stream.request_id] = stream
stream.finish(asyncio.CancelledError)
finished_requests.discard(request_id)
else:
self._request_streams[request_id] = stream
new_requests.append(new_request)
return new_requests, finished_requests
......@@ -220,9 +257,25 @@ class RequestTracker:
return not self._new_requests.empty()
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
    # Sampler output kept from an earlier step; None until one is stored.
    last_output: Optional[SamplerOutput] = None
    # Sequence-group metadata kept from a scheduler call; None when unset.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scheduler outputs kept from a scheduler call; None when unset.
    scheduler_outputs: Optional[SchedulerOutputs] = None
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
def __init__(self, *args, **kwargs):
    """Initialise the base engine, then allocate the scheduler-output caches.

    One ``SchedulerOutputState`` is created for each pipeline-parallel
    stage, as given by ``parallel_config.pipeline_parallel_size``.
    """
    super().__init__(*args, **kwargs)
    num_stages = self.parallel_config.pipeline_parallel_size
    # Every slot starts out as an empty state object
    self.cached_scheduler_outputs = [
        SchedulerOutputState() for _ in range(num_stages)
    ]
async def step_async(
self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
......@@ -235,13 +288,39 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
# these are cached outputs from previous iterations. None if on first
# iteration
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
if not scheduler_outputs.is_empty():
# Execute the model.
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
......@@ -250,15 +329,35 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids)
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
output = []
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
else:
request_outputs = []
# Log stats.
self.do_log_stats(scheduler_outputs, output)
......@@ -268,42 +367,196 @@ class _AsyncLLMEngine(LLMEngine):
return request_outputs
def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
if (not self.scheduler_config.is_multi_step
or not seq_group_metadata_list):
return False
# TODO(will) this is a sanity check for nowto make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
if any([
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))
return ref_remaining_steps > 0
def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs) -> None:
self.cached_scheduler_outputs[
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
scheduler_outputs
self.cached_scheduler_outputs[virtual_engine].last_output = None
def _get_last_sampled_token_ids(
self, virtual_engine: int) -> Optional[torch.Tensor]:
cached_last_output = self.cached_scheduler_outputs[
virtual_engine].last_output
if (self.scheduler_config.is_multi_step
and self.parallel_config.pipeline_parallel_size > 1
and cached_last_output is not None
and cached_last_output.sampled_token_ids_cpu is not None):
return cached_last_output.sampled_token_ids_cpu
return None
def _update_cached_scheduler_output(
self, virtual_engine: int,
output: List[Optional[SamplerOutput]]) -> None:
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
and output[0] is not None):
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_cpu is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output
async def stop_remote_worker_execution_loop_async(self) -> None:
    """Ask the model executor to stop the remote worker execution loop."""
    executor = self.model_executor
    await executor.stop_remote_worker_execution_loop_async()
async def process_model_inputs_async(
async def _tokenize_prompt_async(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")
return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = inputs.get("multi_modal_data")
else:
assert_never(inputs)
return prompt, prompt_token_ids, multi_modal_data
async def _process_encoder_decoder_prompt_async(
    self,
    inputs: PromptInputs,
    request_id: str,
) -> EncoderDecoderLLMInputs:
    """Async version of :meth:`_process_encoder_decoder_prompt`.

    An explicit encoder/decoder prompt has its ``"encoder_prompt"`` and,
    when it is not None, its ``"decoder_prompt"`` extracted (both are
    awaited together through ``asyncio.gather``). Any other prompt is
    used as the encoder prompt, paired with an all-None decoder triple.
    The two triples are passed to ``_build_enc_dec_llm_inputs``.
    """
    encoder_comps: PromptComponents
    decoder_comps: DecoderPromptComponents

    if not is_explicit_encoder_decoder_prompt(inputs):
        encoder_comps = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
        )
        decoder_comps = (None, None, None)
        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    # The coroutine is created up front and only awaited further down
    encoder_coro = self._extract_prompt_components_async(
        inputs["encoder_prompt"],
        request_id=request_id,
    )

    decoder_input = inputs["decoder_prompt"]
    if decoder_input is None:
        encoder_comps = await encoder_coro
        decoder_comps = (None, None, None)
    else:
        decoder_coro = self._extract_prompt_components_async(
            decoder_input,
            request_id=request_id,
        )
        encoder_comps, decoder_comps = await asyncio.gather(
            encoder_coro, decoder_coro)

    return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
lora_request=lora_request,
)
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
prompt_token_ids = await tokenizer.encode_async(
async def process_model_inputs_async(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Async version of :meth:`process_model_inputs`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = await self._process_encoder_decoder_prompt_async(
inputs,
request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
)
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = [
0
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
prompt_token_ids
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
# Decoder-only operation
model_inputs = await self._process_decoder_only_prompt_async(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
return self.input_processor(llm_inputs)
return self.input_processor(model_inputs)
async def add_request_async(
self,
......@@ -315,6 +568,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Async version of :meth:`add_request`."""
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
......@@ -322,10 +576,11 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async(
inputs,
request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request(
request_id=request_id,
......@@ -380,6 +635,20 @@ class AsyncLLMEngine:
self.log_requests = log_requests
self.engine = self._init_engine(*args, **kwargs)
if self.engine_use_ray:
print_warning_once(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
"be removed in a future update. "
"See https://github.com/vllm-project/vllm/issues/7045.")
if envs.VLLM_ALLOW_ENGINE_USE_RAY:
print_warning_once(
"VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray")
else:
raise ValueError("`--engine-use-ray` is deprecated. "
"Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
"force use it")
self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
......@@ -497,6 +766,11 @@ class AsyncLLMEngine:
def errored(self) -> bool:
return self._errored_with is not None
@property
def limit_concurrency(self) -> Optional[int]:
    """Maximum number of concurrently running requests.

    This implementation always returns ``None``.
    """
    # NOTE(review): None presumably means "no limit" — confirm with callers.
    return None
def set_errored(self, exc: Exception) -> None:
self._errored_with = exc
......@@ -507,7 +781,7 @@ class AsyncLLMEngine:
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
) -> AnyTokenizer:
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote( # type: ignore
lora_request)
......@@ -531,6 +805,20 @@ class AsyncLLMEngine:
partial(_log_task_completion, error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def shutdown_background_loop(self) -> None:
    """
    Shut down the background loop.

    Call this during cleanup: it cancels the unshielded background task
    (if there is one) and clears both loop references, removing the
    references to `self` so that the resources held by the async LLM
    engine (e.g., the executors as well as their resources) can be
    garbage collected.
    """
    pending = self._background_loop_unshielded
    if pending is not None:
        pending.cancel()
        self._background_loop_unshielded = None
    # The shielded handle is dropped whether or not a task was running
    self.background_loop = None
def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
if not self.engine_use_ray:
......@@ -556,8 +844,8 @@ class AsyncLLMEngine:
Returns True if there are in-progress requests."""
new_requests, finished_requests = (
self._request_tracker.get_new_and_finished_requests())
new_requests, aborted_requests = (
self._request_tracker.get_new_and_aborted_requests())
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
......@@ -576,8 +864,8 @@ class AsyncLLMEngine:
verbose=self.log_requests,
)
if finished_requests:
await self._engine_abort(finished_requests)
if aborted_requests:
await self._engine_abort(aborted_requests)
if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore
......@@ -666,6 +954,8 @@ class AsyncLLMEngine:
raise
await asyncio.sleep(0)
# This method does not need to be async, but kept that way
# for backwards compatibility.
async def add_request(
self,
request_id: str,
......@@ -675,7 +965,7 @@ class AsyncLLMEngine:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream:
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
......@@ -686,20 +976,17 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
if arrival_time is None:
arrival_time = time.time()
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
params=params,
arrival_time=arrival_time,
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
return stream
return stream.generator()
async def generate(
self,
......@@ -709,7 +996,7 @@ class AsyncLLMEngine:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
......@@ -774,7 +1061,7 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self._process_request(
async for output in await self.add_request(
request_id,
inputs,
sampling_params,
......@@ -791,7 +1078,7 @@ class AsyncLLMEngine:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
......@@ -852,7 +1139,7 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self._process_request(
async for output in await self.add_request(
request_id,
inputs,
pooling_params,
......@@ -861,37 +1148,6 @@ class AsyncLLMEngine:
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
async def _process_request(
    self,
    request_id: str,
    inputs: PromptInputs,
    params: Union[SamplingParams, PoolingParams],
    *,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
    """Common logic to process requests with SamplingParams or
    PoolingParams.

    Registers the request through ``self.add_request`` (stamping it with
    the current time as its arrival time) and yields every item produced
    by the returned stream. If iterating the stream raises an
    ``Exception`` or ``asyncio.CancelledError``, the request is aborted
    via ``self._abort`` and the same exception is re-raised.
    """
    arrival_time = time.time()

    stream = await self.add_request(
        request_id,
        inputs,
        params,
        arrival_time=arrival_time,
        lora_request=lora_request,
        trace_headers=trace_headers,
        prompt_adapter_request=prompt_adapter_request,
    )

    try:
        async for request_output in stream:
            yield request_output
    # CancelledError is listed explicitly in addition to Exception so that
    # cancellation also triggers the abort below.
    except (Exception, asyncio.CancelledError) as e:
        self._abort(request_id)
        raise e
async def abort(self, request_id: str) -> None:
"""Abort a request.
......@@ -920,6 +1176,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
"""
self._request_tracker.abort_request(request_id,
exception=asyncio.CancelledError,
verbose=self.log_requests)
async def get_model_config(self) -> ModelConfig:
......@@ -1009,3 +1266,9 @@ class AsyncLLMEngine:
logger_name=logger_name))
else:
self.engine.remove_logger(logger_name=logger_name)
async def start_profile(self) -> None:
    """Invoke ``"start_profile"`` on the workers via the engine's executor."""
    # NOTE(review): relies on the executor's private ``_run_workers``;
    # presumably requires ``self.engine`` to be a local engine object —
    # confirm behaviour for other engine/executor kinds.
    self.engine.model_executor._run_workers("start_profile")
async def stop_profile(self) -> None:
    """Invoke ``"stop_profile"`` on the workers via the engine's executor."""
    # NOTE(review): relies on the executor's private ``_run_workers``;
    # presumably requires ``self.engine`` to be a local engine object —
    # confirm behaviour for other engine/executor kinds.
    self.engine.model_executor._run_workers("stop_profile")
......@@ -3,28 +3,33 @@ from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union
from typing import Set, Tuple, Type, Union
from typing_extensions import TypeVar, assert_never
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase, Stats)
from vllm.engine.metrics_types import StatLoggerBase, Stats
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptInputs,
SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
......@@ -38,11 +43,12 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
from vllm.utils import Counter, Device
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
......@@ -62,8 +68,14 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
return config.to_diff_dict()
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
PromptComponents = Tuple[Optional[str], List[int],
Optional[MultiModalDataDict]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional[MultiModalDataDict]]
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
......@@ -89,8 +101,6 @@ class LLMEngine:
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
multimodal_config (Optional): The configuration related to multimodal
models.
speculative_config (Optional): The configuration related to speculative
decoding.
executor_class: The model executor class for managing distributed
......@@ -130,24 +140,6 @@ class LLMEngine:
@classmethod
def validate_outputs(
    cls,
    outputs: GenericSequence[object],
    output_type: Type[_O],
) -> List[_O]:
    """Check that every element of ``outputs`` is an ``output_type``.

    When ``cls.DO_VALIDATE_OUTPUT`` is falsy (and not type checking) the
    input sequence is returned unchanged and unchecked. Otherwise a new
    list holding the same elements is returned.

    Raises:
        TypeError: If validation is on and an element is not an
            instance of ``output_type``.
    """
    do_validate = cls.DO_VALIDATE_OUTPUT

    if not (TYPE_CHECKING or do_validate):
        # Fast path: hand the caller's sequence straight back
        return outputs

    checked: List[_O] = []
    for item in outputs:
        if isinstance(item, output_type):
            checked.append(item)
            continue
        raise TypeError(f"Expected output of type {output_type}, "
                        f"but found type {type(item)}")
    return checked
tokenizer: Optional[BaseTokenizerGroup]
......@@ -161,7 +153,6 @@ class LLMEngine:
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
......@@ -170,6 +161,7 @@ class LLMEngine:
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
......@@ -216,11 +208,12 @@ class LLMEngine:
cache_config.enable_prefix_caching,
)
# TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins
load_general_plugins()
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
......@@ -235,16 +228,26 @@ class LLMEngine:
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.input_processor = INPUT_REGISTRY.create_input_processor(
self.model_config)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
model_config)
self.model_executor = executor_class(
model_config=model_config,
......@@ -253,14 +256,12 @@ class LLMEngine:
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
multimodal_config=multimodal_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
init_success = False
try:
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
......@@ -320,6 +321,13 @@ class LLMEngine:
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from vllm.engine.metrics import (LoggingStatLogger,
PrometheusStatLogger)
self.stat_loggers = {
"logging":
LoggingStatLogger(
......@@ -339,11 +347,6 @@ class LLMEngine:
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
def get_tokenizer_for_seq(self,
                          sequence: Sequence) -> "PreTrainedTokenizer":
    """Return the tokenizer selected by ``sequence.lora_request``.

    The lookup is delegated to the tokenizer group returned by
    ``self.get_tokenizer_group()``.
    """
    tokenizer_group = self.get_tokenizer_group()
    return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
......@@ -358,13 +361,6 @@ class LLMEngine:
get_tokenizer_for_seq,
),
))
init_success = True
finally:
if not init_success:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self.model_executor.shutdown()
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
......@@ -482,11 +478,20 @@ class LLMEngine:
def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
group_type: Type[_G] = BaseTokenizerGroup,
*,
missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
) -> _G:
tokenizer_group = self.tokenizer
return self.tokenizer
if tokenizer_group is None:
raise ValueError(missing_msg)
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return tokenizer_group
def get_tokenizer(
self,
......@@ -494,10 +499,6 @@ class LLMEngine:
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
# def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
# return self.get_tokenizer_group().get_lora_tokenizer(
# sequence.lora_request)
def _init_tokenizer(self) -> BaseTokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
......@@ -516,8 +517,19 @@ class LLMEngine:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
def _get_bos_token_id(self,
                      lora_request: Optional[LoRARequest] = None
                      ) -> Optional[int]:
    '''
    Look up the BOS token id of the tokenizer selected by `lora_request`.

    Returns None (after logging a warning) when the engine has no
    tokenizer, i.e. `self.tokenizer` is None.
    '''
    tokenizer_group = self.tokenizer
    if tokenizer_group is not None:
        return tokenizer_group.get_lora_tokenizer(lora_request).bos_token_id

    logger.warning("Using None for BOS token id because tokenizer "
                   "is not initialized")
    return None
def _get_eos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
......@@ -525,16 +537,43 @@ class LLMEngine:
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _get_decoder_start_token_id(self) -> Optional[int]:
    '''
    Obtain the decoder start token id employed by an encoder/decoder
    model. Returns None for non-encoder/decoder models or if the
    model config is unavailable.

    When the HF config does not define `decoder_start_token_id`, the
    BOS token id from `self._get_bos_token_id()` is used instead (which
    may itself be None).
    '''
    # The config-presence check must come first: is_encoder_decoder_model()
    # dereferences self.model_config, so checking afterwards would raise
    # AttributeError instead of returning None as documented.
    model_config = self.model_config
    if model_config is None:
        logger.warning("Using None for decoder start token id because "
                       "model config is not available.")
        return None

    if not self.is_encoder_decoder_model():
        logger.warning("Using None for decoder start token id because "
                       "this is not an encoder/decoder model.")
        return None

    hf_config = model_config.hf_config
    if hf_config is None:
        logger.warning("Using None for decoder start token id because "
                       "model config is not available.")
        return None

    dec_start_token_id = getattr(hf_config, 'decoder_start_token_id', None)
    if dec_start_token_id is None:
        logger.warning("Falling back on <BOS> for decoder start token id "
                       "because decoder start token id is not available.")
        dec_start_token_id = self._get_bos_token_id()

    return dec_start_token_id
def _add_processed_request(
self,
request_id: str,
processed_inputs: LLMInputs,
processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
) -> None:
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
......@@ -543,6 +582,16 @@ class LLMEngine:
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
encoder_seq = None
if 'encoder_prompt_token_ids' in processed_inputs:
encoder_seq = Sequence(seq_id,
processed_inputs,
block_size,
eos_token_id,
lora_request,
prompt_adapter_request,
from_decoder_prompt=False)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
......@@ -552,7 +601,8 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
......@@ -560,7 +610,8 @@ class LLMEngine:
params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
......@@ -576,36 +627,333 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None:
    """Forward the stop request to ``self.model_executor``."""
    executor = self.model_executor
    executor.stop_remote_worker_execution_loop()
def process_model_inputs(
_LLMInputComponentsType = Tuple[str, List[int]]
def _prepare_decoder_input_ids_for_generation(
    self,
    decoder_input_ids: Optional[List[int]],
) -> List[int]:
    """
    Make sure the decoder prompt of an encoder-decoder model begins with
    the decoder start token.

    Modelled on GenerationMixin._prepare_decoder_input_ids_for_generation()
    in HuggingFace transformers:
    https://github.com/huggingface/transformers/blob/
    4037a2b5b1278736e566aec12e169100275545ea/
    src/transformers/generation/utils.py

    Arguments:

    * decoder_input_ids: decoder prompt token ids, or None to use the
      default decoder prompt

    Returns:

    * Token list starting with the decoder start token id; the input
      list is returned as-is when it already starts with that id
    """
    start_token_id = self._get_decoder_start_token_id()
    assert start_token_id is not None

    # No decoder prompt supplied -> start from the default decoder prompt.
    token_ids = (self._get_default_enc_dec_decoder_prompt()
                 if decoder_input_ids is None else decoder_input_ids)

    needs_start_token = (len(token_ids) == 0
                         or token_ids[0] != start_token_id)
    if not needs_start_token:
        return token_ids
    return [start_token_id] + token_ids
def _tokenize_prompt(
    self,
    prompt: str,
    request_id: str,
    lora_request: Optional[LoRARequest],
) -> List[int]:
    '''
    Encode a text prompt with the engine's tokenizer group.

    Arguments:

    * prompt: text to tokenize
    * request_id: forwarded to the tokenizer group's `encode`
    * lora_request: forwarded to the tokenizer group's `encode`

    Returns:

    * prompt token ids
    '''
    no_tokenizer_msg = "prompts must be None if skip_tokenizer_init is True"
    tokenizer_group = self.get_tokenizer_group(missing_msg=no_tokenizer_msg)

    token_ids = tokenizer_group.encode(request_id=request_id,
                                       prompt=prompt,
                                       lora_request=lora_request)
    return token_ids
def _extract_prompt_components(
    self,
    inputs: SingletonPromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
    '''
    Split a single (encoder or decoder) input prompt into its parts.

    A plain string is tokenized. A dict either supplies
    "prompt_token_ids" directly (the text is then reported as None) or
    supplies "prompt", which is tokenized. Multi-modal data is only
    read from dict inputs.

    Arguments:

    * request_id
    * inputs: single encoder or decoder input prompt
    * lora_request: this is only valid for decoder prompts

    Returns:

    * prompt
    * prompt_token_ids
    * multi_modal_data
    '''
    if isinstance(inputs, str):
        token_ids = self._tokenize_prompt(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )
        return inputs, token_ids, None

    if isinstance(inputs, dict):
        if "prompt_token_ids" in inputs:
            # Pre-tokenized input: no text is associated with the prompt.
            return (None, inputs["prompt_token_ids"],
                    inputs.get("multi_modal_data"))

        raw_prompt = inputs["prompt"]
        token_ids = self._tokenize_prompt(
            raw_prompt,
            request_id=request_id,
            lora_request=lora_request,
        )
        return raw_prompt, token_ids, inputs.get("multi_modal_data")

    # Neither str nor dict: unsupported input kind.
    assert_never(inputs)
def _apply_prompt_adapter(
    self,
    prompt_token_ids: List[int],
    prompt_adapter_request: Optional[PromptAdapterRequest],
) -> List[int]:
    '''
    Prepend one 0 token per prompt-adapter virtual token.

    Returns `prompt_token_ids` unchanged when `prompt_adapter_request`
    is falsy; otherwise returns a new list with
    `prompt_adapter_num_virtual_tokens` zeros in front of the prompt.
    '''
    if not prompt_adapter_request:
        return prompt_token_ids

    num_virtual_tokens = (
        prompt_adapter_request.prompt_adapter_num_virtual_tokens)
    placeholder_ids = [0] * num_virtual_tokens
    return placeholder_ids + prompt_token_ids
def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
    '''
    Encoder/decoder models only: build the decoder prompt used when
    the caller provides just an encoder prompt.

    The default is the single-token prompt <BOS>. This approximates
    HuggingFace GenerationMixin, which, for a None decoder prompt,
    uses a logit processor setting to force the first decoded token
    to be <BOS>.

    Keeping this in its own helper leaves room for model-specific
    default decoder prompts as further encoder/decoder models are
    added.

    Returns:

    * prompt_token_ids
    '''
    default_token_id = self._get_bos_token_id()
    assert default_token_id is not None
    return [default_token_id]
def _build_enc_dec_llm_inputs(
    self,
    encoder_comps: PromptComponents,
    decoder_comps: DecoderPromptComponents,
) -> EncoderDecoderLLMInputs:
    '''
    Assemble an :class:`EncoderDecoderLLMInputs` from the extracted
    encoder and decoder prompt components.

    The decoder token ids are first passed through
    `_prepare_decoder_input_ids_for_generation`. Raises ValueError
    when either prompt carries multi-modal data.
    '''
    enc_text, enc_token_ids, enc_mm_data = encoder_comps
    dec_text, dec_token_ids, dec_mm_data = decoder_comps

    has_mm_data = not (enc_mm_data is None and dec_mm_data is None)
    if has_mm_data:
        raise ValueError("Multi-modal encoder-decoder models are "
                         "not supported yet")

    prepared_dec_ids = self._prepare_decoder_input_ids_for_generation(
        dec_token_ids)

    return EncoderDecoderLLMInputs(
        prompt_token_ids=prepared_dec_ids,
        prompt=dec_text,
        encoder_prompt_token_ids=enc_token_ids,
        encoder_prompt=enc_text,
    )
def _process_encoder_decoder_prompt(
    self,
    inputs: PromptInputs,
    request_id: str,
) -> EncoderDecoderLLMInputs:
    '''
    For encoder/decoder models only:
    turn an input prompt into an
    :class:`EncoderDecoderLLMInputs` instance.

    Two prompt shapes are accepted:

    * Singleton prompt: treated as the encoder prompt; the decoder
      prompt is left empty so that a default one is inferred
    * Explicit encoder/decoder prompt: carries "encoder_prompt" and
      "decoder_prompt"; a None decoder prompt is handled like the
      singleton case

    Each sub-prompt of an explicit prompt may itself be any singleton
    type, so token ids are obtained through
    `_extract_prompt_components`.

    Arguments:

    * inputs: an input prompt
    * request_id

    Returns:

    * :class:`EncoderDecoderLLMInputs` instance
    '''
    encoder_comps: PromptComponents
    decoder_comps: DecoderPromptComponents

    if not is_explicit_encoder_decoder_prompt(inputs):
        # Singleton prompt: the whole input is the encoder prompt.
        encoder_comps = self._extract_prompt_components(
            inputs,
            request_id=request_id,
        )
        decoder_comps = None, None, None
        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    encoder_comps = self._extract_prompt_components(
        inputs["encoder_prompt"],
        request_id=request_id,
    )

    decoder_input = inputs["decoder_prompt"]
    if decoder_input is not None:
        decoder_comps = self._extract_prompt_components(
            decoder_input,
            request_id=request_id,
        )
    else:
        decoder_comps = None, None, None

    return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
def _build_decoder_only_llm_inputs(
    self,
    prompt_comps: PromptComponents,
    prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
    '''
    Assemble an :class:`LLMInputs` from extracted prompt components,
    after passing the token ids through `_apply_prompt_adapter`.
    '''
    text, token_ids, mm_data = prompt_comps

    adapted_token_ids = self._apply_prompt_adapter(
        token_ids, prompt_adapter_request=prompt_adapter_request)

    return LLMInputs(prompt_token_ids=adapted_token_ids,
                     prompt=text,
                     multi_modal_data=mm_data)
def _process_decoder_only_prompt(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
'''
For decoder-only models:
Process an input prompt into an :class:`LLMInputs` instance.
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
Arguments:
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
else:
prompt_token_ids = inputs["prompt_token_ids"]
* inputs: input prompt
* request_id
* lora_request
* prompt_adapter_request
if prompt_adapter_request:
prompt_token_ids = \
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
+ prompt_token_ids
Returns:
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
* :class:`LLMInputs` instance
'''
return self.input_processor(llm_inputs)
prompt_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
lora_request=lora_request,
)
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
def process_model_inputs(
    self,
    inputs: PromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
    '''
    Convert a user prompt into model inputs and run them through
    `self.input_processor`.

    Encoder/decoder models go through
    `_process_encoder_decoder_prompt` (`lora_request` and
    `prompt_adapter_request` are not forwarded on that path).
    Decoder-only models go through `_process_decoder_only_prompt` and
    reject explicit encoder/decoder prompts with a ValueError.
    '''
    if self.is_encoder_decoder_model():
        # Encoder-decoder model requires special mapping of
        # input prompts to encoder & decoder
        enc_dec_inputs = self._process_encoder_decoder_prompt(
            inputs,
            request_id=request_id,
        )
        return self.input_processor(enc_dec_inputs)

    if is_explicit_encoder_decoder_prompt(inputs):
        raise ValueError("Cannot pass encoder-decoder prompt "
                         "to decoder-only models")

    # Decoder-only operation
    decoder_only_inputs = self._process_decoder_only_prompt(
        inputs,
        request_id=request_id,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )
    return self.input_processor(decoder_only_inputs)
def add_request(
self,
......@@ -666,10 +1014,11 @@ class LLMEngine:
arrival_time = time.time()
processed_inputs = self.process_model_inputs(
inputs,
request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request(
request_id=request_id,
......@@ -690,6 +1039,7 @@ class LLMEngine:
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
......@@ -715,7 +1065,8 @@ class LLMEngine:
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq)
return seq_group
......@@ -727,6 +1078,7 @@ class LLMEngine:
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
encoder_seq: Optional[Sequence] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
......@@ -738,7 +1090,8 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
......@@ -836,6 +1189,22 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if output is not None and len(output) > 0:
for o in output:
if (isinstance(o, SamplerOutput)
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
seq_group.metrics.model_forward_time += (
o.model_forward_time)
else:
seq_group.metrics.model_forward_time = (
o.model_forward_time)
if seq_group.metrics.model_execute_time is not None:
seq_group.metrics.model_execute_time += (
o.model_execute_time)
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, outputs)
continue
......@@ -916,6 +1285,11 @@ class LLMEngine:
raise NotImplementedError(
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")
if self.scheduler_config.num_scheduler_steps > 1:
raise NotImplementedError(
"Multiple scheduler steps (multi-step) are only supported "
"through AsyncLLMEngine. ")
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
......@@ -1015,6 +1389,13 @@ class LLMEngine:
for scheduler in self.scheduler)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Prefix Cache Hit Rate. Note that we always use
# the cache hit rate of the first virtual engine.
cpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.CPU)
gpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.GPU)
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
......@@ -1123,6 +1504,9 @@ class LLMEngine:
# KV Cache Usage in %
gpu_cache_usage_sys=gpu_cache_usage_sys,
cpu_cache_usage_sys=cpu_cache_usage_sys,
# Prefix Cache Hit Rate
cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
# Iteration stats
num_prompt_tokens_iter=num_prompt_tokens_iter,
......@@ -1228,3 +1612,28 @@ class LLMEngine:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
if metrics.scheduler_time is not None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
metrics.scheduler_time)
if metrics.model_forward_time is not None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
metrics.model_forward_time / 1000.0)
if metrics.model_execute_time is not None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time)
def is_encoder_decoder_model(self):
    """Return the ``is_encoder_decoder_model`` flag of the model config."""
    model_config = self.model_config
    return model_config.is_encoder_decoder_model
def is_embedding_model(self):
    """Return the ``is_embedding_model`` flag of the model config."""
    model_config = self.model_config
    return model_config.is_embedding_model
def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                               EncoderDecoderLLMInputs]):
    """Reject inputs whose prompt token ids are missing or empty.

    Encoder/decoder models are checked on "encoder_prompt_token_ids";
    all other models on "prompt_token_ids".

    Raises:
        ValueError: If the selected entry is absent or falsy.
    """
    if self.is_encoder_decoder_model():
        prompt_key = "encoder_prompt_token_ids"
    else:
        prompt_key = "prompt_token_ids"

    token_ids = inputs.get(prompt_key)
    if not token_ids:
        raise ValueError("Prompt cannot be empty")
\ No newline at end of file
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Protocol, Union
from typing import Dict, List, Optional, Union
import numpy as np
import prometheus_client
from vllm.engine.metrics_types import (StatLoggerBase, Stats,
SupportsMetricsInfo)
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
......@@ -29,41 +28,60 @@ prometheus_client.disable_created_metrics()
# begin-metrics-definitions
class Metrics:
"""
vLLM uses a multiprocessing-based frontend for the OpenAI server.
This means that we need to run prometheus_client in multiprocessing mode
See https://prometheus.github.io/client_python/multiprocess/ for more
details on limitations.
"""
labelname_finish_reason = "finished_reason"
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors
# Unregister any existing vLLM collectors (for CI/CD)
self._unregister_vllm_metrics()
# Config Information
self._create_info_cache_config()
# System stats
# Scheduler State
self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_scheduler_swapped = self._gauge_cls(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
# KV Cache Usage in %
self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_cpu_cache_usage = self._gauge_cls(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
# Prefix caching block hit rate
self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
name="vllm:cpu_prefix_cache_hit_rate",
documentation="CPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
name="vllm:gpu_prefix_cache_hit_rate",
documentation="GPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")
# Iteration stats
self.counter_num_preemption = self._counter_cls(
......@@ -137,11 +155,13 @@ class Metrics:
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_spec_decode_efficiency = self._gauge_cls(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
......@@ -160,19 +180,18 @@ class Metrics:
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
multiprocess_mode="sum",
)
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = self._gauge_cls(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
multiprocess_mode="sum",
)
def _create_info_cache_config(self) -> None:
# Config Information
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')
# end-metrics-definitions
def _unregister_vllm_metrics(self) -> None:
for collector in list(prometheus_client.REGISTRY._collector_to_names):
......@@ -180,9 +199,6 @@ class Metrics:
prometheus_client.REGISTRY.unregister(collector)
# end-metrics-definitions
class _RayGaugeWrapper:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
......@@ -190,7 +206,9 @@ class _RayGaugeWrapper:
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None):
labelnames: Optional[List[str]] = None,
multiprocess_mode: str = ""):
del multiprocess_mode
labelnames_tuple = tuple(labelnames) if labelnames else None
self._gauge = ray_metrics.Gauge(name=name,
description=documentation,
......@@ -268,10 +286,6 @@ class RayMetrics(Metrics):
# No-op on purpose
pass
def _create_info_cache_config(self) -> None:
# No-op on purpose
pass
def build_1_2_5_buckets(max_value: int) -> List[int]:
"""
......@@ -295,46 +309,6 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
exponent += 1
@dataclass
class Stats:
    """Statistics record that LLMEngine builds and passes to StatLogger."""
    now: float

    # ---- System stats (fields carry the _sys suffix) ----
    # Scheduler state.
    num_running_sys: int
    num_waiting_sys: int
    num_swapped_sys: int
    # KV cache usage, in %.
    gpu_cache_usage_sys: float
    cpu_cache_usage_sys: float

    # ---- Iteration stats (fields carry the _iter suffix) ----
    num_prompt_tokens_iter: int
    num_generation_tokens_iter: int
    time_to_first_tokens_iter: List[float]
    time_per_output_tokens_iter: List[float]
    num_preemption_iter: int

    # ---- Request stats (fields carry the _requests suffix) ----
    # Latency.
    time_e2e_requests: List[float]
    # Metadata.
    num_prompt_tokens_requests: List[int]
    num_generation_tokens_requests: List[int]
    best_of_requests: List[int]
    n_requests: List[int]
    finished_reason_requests: List[str]

    # Optional speculative-decoding metrics; None when not supplied.
    spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class SupportsMetricsInfo(Protocol):
    """Structural type for objects that expose ``metrics_info()``."""

    def metrics_info(self) -> Dict[str, str]:
        """Return the object's info as a string-to-string mapping."""
        ...
def local_interval_elapsed(now: float, last_log: float,
local_interval: float) -> bool:
elapsed_time = now - last_log
......@@ -346,38 +320,9 @@ def get_throughput(tracked_stats: List[int], now: float,
return float(np.sum(tracked_stats) / (now - last_log))
class StatLoggerBase(ABC):
    """Abstract base for stat loggers; subclasses provide info() and log()."""

    def __init__(self, local_interval: float) -> None:
        self.local_interval = local_interval
        self.last_local_log = time.time()
        # Stats tracked over the current local logging interval.
        self.num_prompt_tokens: List[int] = []
        self.num_generation_tokens: List[int] = []
        self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    @abstractmethod
    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        """Record info metrics of kind ``type`` taken from ``obj``."""
        raise NotImplementedError

    @abstractmethod
    def log(self, stats: Stats) -> None:
        """Consume one ``Stats`` record."""
        raise NotImplementedError

    def maybe_update_spec_decode_metrics(self, stats: Stats):
        """Remember the spec decode metrics carried by ``stats``, if any.

        They are saved because they are unlikely to be emitted at the
        same time as the log interval.
        """
        latest_metrics = stats.spec_decode_metrics
        if latest_metrics is None:
            return
        self.spec_decode_metrics = latest_metrics
class LoggingStatLogger(StatLoggerBase):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
    """Not implemented for this logger; always raises NotImplementedError."""
    raise NotImplementedError
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Logs to Stdout every self.local_interval seconds."""
......@@ -417,7 +362,13 @@ class LoggingStatLogger(StatLoggerBase):
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
if (stats.cpu_prefix_cache_hit_rate >= 0
or stats.gpu_prefix_cache_hit_rate >= 0):
logger.info(
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
stats.gpu_prefix_cache_hit_rate * 100,
stats.cpu_prefix_cache_hit_rate * 100,
)
if self.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
......@@ -440,10 +391,14 @@ class LoggingStatLogger(StatLoggerBase):
f"Number of draft tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens: {metrics.emitted_tokens}.")
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
    """Not implemented for this logger; always raises NotImplementedError."""
    raise NotImplementedError
class PrometheusStatLogger(StatLoggerBase):
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
_metrics_cls = Metrics
_gauge_cls = prometheus_client.Gauge
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
......@@ -453,10 +408,6 @@ class PrometheusStatLogger(StatLoggerBase):
self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
max_model_len=max_model_len)
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
    """Publish ``obj.metrics_info()`` through the cache-config Info metric.

    Only ``type == "cache_config"`` is handled; any other type is
    silently ignored.
    """
    if type != "cache_config":
        return
    cache_config_info = obj.metrics_info()
    self.metrics.info_cache_config.info(cache_config_info)
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
    """Set ``gauge`` to ``data`` under this logger's labels (``self.labels``)."""
    labelled_gauge = gauge.labels(**self.labels)
    labelled_gauge.set(data)
......@@ -489,6 +440,10 @@ class PrometheusStatLogger(StatLoggerBase):
stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
stats.cpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
stats.cpu_prefix_cache_hit_rate)
self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
stats.gpu_prefix_cache_hit_rate)
# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
......@@ -586,6 +541,19 @@ class PrometheusStatLogger(StatLoggerBase):
self.last_local_log = stats.now
self.spec_decode_metrics = None
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
    """Expose ``obj.metrics_info()`` as an info-style gauge.

    Only ``type == "cache_config"`` is handled; any other type is
    silently ignored.
    """
    if type != "cache_config":
        return

    # Info type metrics are syntactic sugar for a gauge permanently set
    # to 1. Prometheus multiprocessing mode does not support Info, so it
    # is emulated here with a gauge labelled by the info entries.
    cache_config_info = obj.metrics_info()
    cache_config_gauge = self._gauge_cls(
        name="vllm:cache_config_info",
        documentation="Information of the LLMEngine CacheConfig",
        labelnames=cache_config_info.keys(),
        multiprocess_mode="mostrecent")
    cache_config_gauge.labels(**cache_config_info).set(1)
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
......
"""
These types are defined in this file to avoid importing vllm.engine.metrics
and therefore importing prometheus_client.
This is required due to usage of Prometheus multiprocess mode to enable
metrics after splitting out the uvicorn process from the engine process.
Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR
before prometheus_client is imported. Typically, this is done by setting
the env variable before launch, but since we are a library, we need to
do this in Python code and lazily import prometheus_client.
"""
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass
class Stats:
    """Statistics record that LLMEngine builds and passes to StatLogger."""
    now: float

    # ---- System stats (fields carry the _sys suffix) ----
    # Scheduler state.
    num_running_sys: int
    num_waiting_sys: int
    num_swapped_sys: int
    # KV cache usage, in %.
    gpu_cache_usage_sys: float
    cpu_cache_usage_sys: float
    # Prefix caching block hit rate.
    cpu_prefix_cache_hit_rate: float
    gpu_prefix_cache_hit_rate: float

    # ---- Iteration stats (fields carry the _iter suffix) ----
    num_prompt_tokens_iter: int
    num_generation_tokens_iter: int
    time_to_first_tokens_iter: List[float]
    time_per_output_tokens_iter: List[float]
    num_preemption_iter: int

    # ---- Request stats (fields carry the _requests suffix) ----
    # Latency.
    time_e2e_requests: List[float]
    # Metadata.
    num_prompt_tokens_requests: List[int]
    num_generation_tokens_requests: List[int]
    best_of_requests: List[int]
    n_requests: List[int]
    finished_reason_requests: List[str]

    # Optional speculative-decoding metrics; None when not supplied.
    spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class SupportsMetricsInfo(Protocol):
    """Structural type for objects that expose ``metrics_info()``."""

    def metrics_info(self) -> Dict[str, str]:
        """Return the object's info as a string-to-string mapping."""
        ...
class StatLoggerBase(ABC):
    """Abstract base for stat loggers; subclasses provide log() and info()."""

    def __init__(self, local_interval: float) -> None:
        self.local_interval = local_interval
        self.last_local_log = time.time()
        # Stats tracked over the current local logging interval.
        self.num_prompt_tokens: List[int] = []
        self.num_generation_tokens: List[int] = []
        self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    @abstractmethod
    def log(self, stats: Stats) -> None:
        """Consume one ``Stats`` record."""
        raise NotImplementedError

    @abstractmethod
    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        """Record info metrics of kind ``type`` taken from ``obj``."""
        raise NotImplementedError

    def maybe_update_spec_decode_metrics(self, stats: Stats):
        """Remember the spec decode metrics carried by ``stats``, if any.

        They are saved because they are unlikely to be emitted at the
        same time as the log interval.
        """
        latest_metrics = stats.spec_decode_metrics
        if latest_metrics is None:
            return
        self.spec_decode_metrics = latest_metrics
from abc import ABC, abstractmethod
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
......@@ -29,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.
......
import functools
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
......@@ -12,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
......@@ -36,7 +35,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: StopChecker,
):
self.detokenizer = detokenizer
......
from typing import Callable, Optional
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker:
......@@ -15,8 +14,7 @@ class StopChecker:
"""
def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence],
PreTrainedTokenizer]):
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
......
from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
runtime_checkable)
from transformers import PreTrainedTokenizer
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
......@@ -12,6 +10,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
......@@ -30,7 +29,11 @@ class AsyncEngineClient(Protocol):
def errored(self) -> bool:
...
async def generate(
@property
def limit_concurrency(self) -> Optional[int]:
    """Maximum number of concurrently running requests.

    ``serve_http`` forwards this value to uvicorn as ``limit_concurrency``
    only when it is not None.
    """
    ...
def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
......@@ -38,18 +41,20 @@ class AsyncEngineClient(Protocol):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request"""
...
async def encode(
def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
...
async def abort(self, request_id: str) -> None:
"""Abort a request.
......@@ -60,25 +65,37 @@ class AsyncEngineClient(Protocol):
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
...
async def get_decoding_config(self) -> DecodingConfig:
    """Get the decoding configuration of the vLLM engine."""
    # The string must be the first statement to be picked up as the
    # docstring; placed after `...` it is a discarded expression.
    ...
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> PreTrainedTokenizer:
"""Get the appropriate Tokenizer for the request"""
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
...
async def is_tracing_enabled(self) -> bool:
pass
...
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass
...
async def check_health(self) -> None:
"""Raise if unhealthy"""
...
async def start_profile(self) -> None:
"""Start profiling the engine"""
...
async def stop_profile(self) -> None:
    """Stop profiling the engine"""
    ...
......@@ -20,7 +20,8 @@ from vllm.entrypoints.launcher import serve_http
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
random_uuid)
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")
......@@ -53,11 +54,14 @@ async def generate(request: Request) -> Response:
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
results_generator = iterate_with_cancellation(
results_generator, is_cancelled=request.is_disconnected)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
assert prompt is not None
text_outputs = [
prompt + output.text for output in request_output.outputs
]
......@@ -69,15 +73,15 @@ async def generate(request: Request) -> Response:
# Non-streaming case
final_output = None
try:
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
except asyncio.CancelledError:
return Response(status_code=499)
assert final_output is not None
prompt = final_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)
......@@ -113,9 +117,11 @@ async def run_server(args: Namespace,
logger.info("args: %s", args)
app = await init_app(args, llm_engine)
assert engine is not None
shutdown_task = await serve_http(
app,
engine=engine,
host=args.host,
port=args.port,
log_level=args.log_level,
......
import codecs
from dataclasses import dataclass
from functools import lru_cache
from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
final)
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union)
# yapf conflicts with isort for this block
# yapf: disable
......@@ -14,18 +15,33 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer
from typing_extensions import Required, TypedDict
from pydantic import ConfigDict, TypeAdapter
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
class AudioURL(TypedDict, total=False):
    """Payload of an ``audio_url`` content part; only ``url`` is required."""
    # Passed to async_get_and_parse_audio when a chat message is parsed.
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Chat message content part that carries an audio input.

    Both keys are required even though the TypedDict is ``total=False``.
    """
    # Where to obtain the audio; see AudioURL.
    audio_url: Required[AudioURL]

    # Discriminator: must be the literal "audio_url" for this part kind.
    type: Required[Literal["audio_url"]]
    """The type of the content part."""
class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
......@@ -33,8 +49,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
CustomChatCompletionContentPartParam]
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam, ]
class CustomChatCompletionMessageParam(TypedDict, total=False):
......@@ -57,7 +74,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
@final # So that it should be compatible with Dict[str, str]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict):
role: str
content: str
......@@ -69,13 +86,17 @@ class ChatMessageParseResult:
mm_futures: List[Awaitable[MultiModalDataDict]]
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
return None
try:
with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
if isinstance(chat_template, Path):
raise
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
......@@ -92,11 +113,12 @@ def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
@lru_cache(maxsize=None)
def _image_token_str(model_config: ModelConfig,
tokenizer: PreTrainedTokenizer) -> Optional[str]:
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = model_config.hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
......@@ -109,40 +131,54 @@ def _image_token_str(model_config: ModelConfig,
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
# TODO: Let user specify how to insert image tokens into prompt
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
"""Combine image and text prompts for vision language model"""
def _get_full_multimodal_text_prompt(placeholder_token_str: str,
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return f"{image_token_str}\n{text_prompt}"
# placeholder + text prompt format. This may change in the future.
return f"{placeholder_token_str}\n{text_prompt}"
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
modality: Literal["image", "audio"] = "image"
for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
text = _TextParser.validate_python(part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported.")
"Multiple multimodal inputs is currently not supported.")
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
image_url = _ImageParser.validate_python(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
......@@ -151,21 +187,31 @@ def _parse_chat_message_content_parts(
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
elif part_type == "audio_url":
modality = "audio"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if mm_futures:
image_token_str = _image_token_str(model_config, tokenizer)
if image_token_str is not None:
if image_token_str in text_prompt:
placeholder_token_str = _mm_token_str(model_config, tokenizer,
modality)
if placeholder_token_str is not None:
if placeholder_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Detected multi-modal token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_image_text_prompt(
image_token_str=image_token_str,
text_prompt = _get_full_multimodal_text_prompt(
placeholder_token_str=placeholder_token_str,
text_prompt=text_prompt,
)
......@@ -177,7 +223,7 @@ def _parse_chat_message_content_parts(
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
......@@ -188,14 +234,18 @@ def _parse_chat_message_content(
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
return _parse_chat_message_content_parts(
role,
content, # type: ignore
model_config,
tokenizer,
)
def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
......@@ -208,3 +258,28 @@ def parse_chat_messages(
mm_futures.extend(parse_result.mm_futures)
return conversation, mm_futures
def apply_chat_template(
    tokenizer: AnyTokenizer,
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render ``conversation`` into a prompt string using the tokenizer.

    Args:
        tokenizer: Tokenizer whose ``apply_chat_template`` does the rendering.
        conversation: Messages to render.
        chat_template: Explicit template; when None the tokenizer's own
            ``chat_template`` attribute must be set.
        tokenize: Forwarded to the tokenizer; defaults to False here.
        **kwargs: Forwarded unchanged to ``tokenizer.apply_chat_template``.

    Returns:
        The rendered prompt.

    Raises:
        ValueError: If neither ``chat_template`` nor the tokenizer provides
            a template.
    """
    # The tokenizer's own template is only consulted when the caller
    # supplied none.
    template_missing = (chat_template is None
                        and tokenizer.chat_template is None)
    if template_missing:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one.")

    rendered = tokenizer.apply_chat_template(
        conversation=conversation,
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )
    # The declared return type is str; anything else is rejected here.
    assert isinstance(rendered, str)
    return rendered
import asyncio
import signal
from http import HTTPStatus
from typing import Any
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import AsyncEngineClient
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
logger = init_logger(__name__)
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
async def serve_http(app: FastAPI, engine: AsyncEngineClient,
**uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
......@@ -21,8 +27,18 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if engine.limit_concurrency is not None:
logger.info(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing", engine.limit_concurrency)
uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server, engine)
loop = asyncio.get_running_loop()
......@@ -42,5 +58,45 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
await server_task
return dummy_shutdown()
except asyncio.CancelledError:
port = uvicorn_kwargs["port"]
process = find_process_using_port(port)
if process is not None:
logger.debug(
"port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline()))
logger.info("Gracefully stopping http server")
return server.shutdown()
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
                           engine: AsyncEngineClient) -> None:
    """Register exception handlers that stop the server on fatal errors.

    Both handlers answer the failing request with HTTP 500. Unless
    ``envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH`` is set, they may also flag
    ``server`` to exit.
    """

    @app.exception_handler(RuntimeError)
    async def runtime_error_handler(_, __):
        """On generic runtime error, check to see if the engine has died.
        It probably has, in which case the server will no longer be able to
        handle requests. Trigger a graceful shutdown with a SIGTERM."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            engine_is_dead = engine.errored and not engine.is_running
            if engine_is_dead:
                logger.fatal("AsyncLLMEngine has failed, terminating server "
                             "process")
                # See discussions here on shutting down a uvicorn server
                # https://github.com/encode/uvicorn/discussions/1103
                # In this case we cannot await the server shutdown here
                # because this handler must first return to close the
                # connection for this request.
                server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    @app.exception_handler(AsyncEngineDeadError)
    async def engine_dead_handler(_, __):
        """Kill the server if the async engine is already dead. It will
        not handle any further requests."""
        keep_alive = envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
        if not keep_alive:
            logger.fatal("AsyncLLMEngine is already dead, terminating server "
                         "process")
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
......@@ -2,12 +2,14 @@ from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_chat_template,
parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
......@@ -17,7 +19,9 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs
......@@ -119,18 +123,31 @@ class LLM:
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False for decoder-only models and True
for encoder/decoder models, since encoder/decoder models
do not currently support CUDAGraph.
'''
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size",
"image_input_shape", "image_input_type")
removed_vision_keys = (
"image_token_id",
"image_feature_size",
"image_input_shape",
"image_input_type",
)
if any(k in kwargs for k in removed_vision_keys):
raise TypeError(
"There is no need to pass vision-related arguments anymore.")
......@@ -159,22 +176,19 @@ class LLM:
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
def get_tokenizer(
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer.tokenizer
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
self.llm_engine.tokenizer.tokenizer = tokenizer
tokenizer_group.tokenizer = tokenizer
else:
self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
tokenizer)
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
@overload # LEGACY: single (prompt + optional token ids)
def generate(
......@@ -250,11 +264,12 @@ class LLM:
) -> List[RequestOutput]:
...
@deprecate_kwargs("prompts",
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
additional_message="Please use the 'inputs' parameter instead.",
)
def generate(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
......@@ -287,7 +302,7 @@ class LLM:
generation, if any.
Returns:
A list of `RequestOutput` objects containing the
A list of ``RequestOutput`` objects containing the
generated completions in the same order as the input prompts.
Note:
......@@ -297,8 +312,8 @@ class LLM:
"""
if self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.generate() is only supported for generation models "
"(XForCausalLM).")
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
......@@ -330,6 +345,62 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
def chat(
    self,
    messages: List[ChatCompletionMessageParam],
    sampling_params: Optional[Union[SamplingParams,
                                    List[SamplingParams]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[LoRARequest] = None,
    chat_template: Optional[str] = None,
    add_generation_prompt: bool = True,
) -> List[RequestOutput]:
    """
    Generates responses for one chat conversation.

    The messages are parsed, rendered to a single prompt with the chat
    template, and passed to :meth:`generate`.

    Args:
        messages: The conversation, as a list of messages. Each message
            is a dictionary with 'role' and 'content' keys.
        sampling_params: The sampling parameters for text generation.
            Forwarded unchanged to :meth:`generate`.
        use_tqdm: Whether to use tqdm to display the progress bar.
        lora_request: LoRA request to use for generation, if any.
        chat_template: The template to use for structuring the chat.
            If None, the tokenizer's own chat template is used; a
            ``ValueError`` is raised when the tokenizer has none either.
        add_generation_prompt: Forwarded to the tokenizer's
            ``apply_chat_template`` call.

    Returns:
        A list of ``RequestOutput`` objects, as returned by
        :meth:`generate` for the rendered prompt.
    """
    active_tokenizer = self.get_tokenizer()

    # parse_chat_messages also returns pending multimodal data futures;
    # they are deliberately not used here, so only text reaches the prompt.
    conversation, _ = parse_chat_messages(
        messages,
        self.llm_engine.get_model_config(),
        active_tokenizer,
    )

    prompt = apply_chat_template(
        active_tokenizer,
        conversation,
        chat_template=chat_template,
        add_generation_prompt=add_generation_prompt,
    )

    return self.generate(prompt,
                         sampling_params,
                         use_tqdm=use_tqdm,
                         lora_request=lora_request)
@overload # LEGACY: single (prompt + optional token ids)
def encode(
self,
......@@ -404,11 +475,12 @@ class LLM:
) -> List[EmbeddingRequestOutput]:
...
@deprecate_kwargs("prompts",
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
additional_message="Please use the 'inputs' parameter instead.",
)
def encode(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
......@@ -504,6 +576,8 @@ class LLM:
inputs: List[PromptInputs] = []
for i in range(num_requests):
item: PromptInputs
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:
......@@ -554,15 +628,15 @@ class LLM:
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)
def _add_request(
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest],
LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
......@@ -570,7 +644,8 @@ class LLM:
inputs,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)
def _add_guided_processor(
self,
......@@ -619,8 +694,8 @@ class LLM:
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs)
out_spd = total_out_toks / pbar.format_dict[
"elapsed"]
out_spd = (total_out_toks /
pbar.format_dict["elapsed"])
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
......@@ -631,3 +706,9 @@ class LLM:
# This is necessary because some requests may be finished earlier than
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
def _is_encoder_decoder_model(self):
return self.llm_engine.is_encoder_decoder_model()
def _is_embedding_model(self):
return self.llm_engine.is_embedding_model()
import asyncio
import importlib
import inspect
import multiprocessing
import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from http import HTTPStatus
from multiprocessing import Process
from typing import AsyncIterator, Set
from typing import AsyncIterator, Optional, Set
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.routing import Mount
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
......@@ -28,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
......@@ -43,7 +47,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
......@@ -54,19 +58,23 @@ openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
def model_is_embedding(model_name: str, trust_remote_code: bool,
quantization: str) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
quantization=quantization,
seed=0,
dtype="float16").embedding_mode
dtype="auto").embedding_mode
@asynccontextmanager
......@@ -86,7 +94,16 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
......@@ -97,7 +114,8 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(args.model, args.trust_remote_code)
if (model_is_embedding(args.model, args.trust_remote_code,
args.quantization)
or args.disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
......@@ -106,37 +124,99 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
# Start RPCServer in separate process (holds the AsyncLLMEngine).
port = get_open_port(envs.VLLM_RPC_PORT)
rpc_server_process = Process(target=run_rpc_server,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
port))
rpc_server_process.start()
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
global prometheus_multiproc_dir
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ[
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
else:
logger.warning(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
# Select random path for IPC.
rpc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for RPC Path.",
rpc_path)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(port)
await async_engine_client.setup()
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore
# Start RPCServer in separate process (holds the AsyncLLMEngine).
context = multiprocessing.get_context("spawn")
# the current process might have CUDA context,
# so we need to spawn a new process
rpc_server_process = context.Process(
target=run_rpc_server,
args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
rpc_server_process.start()
logger.info("Started engine process with PID %d",
rpc_server_process.pid)
try:
while True:
try:
await rpc_client.setup()
break
except TimeoutError:
if not rpc_server_process.is_alive():
logger.error(
"RPCServer process died before responding "
"to readiness probe")
yield None
return
yield async_engine_client
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
# Close all open connections to the backend
async_engine_client.close()
rpc_client.close()
# Wait for server process to join
rpc_server_process.join()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(rpc_server_process.pid)
router = APIRouter()
def mount_metrics(app: FastAPI):
    """Append a Prometheus ``/metrics`` route to ``app``.

    When PROMETHEUS_MULTIPROC_DIR is set, the ASGI metrics app is backed by
    a fresh ``CollectorRegistry`` wired to a ``MultiProcessCollector``;
    otherwise the default ``make_asgi_app()`` is used.
    """
    # Lazy import for prometheus multiprocessing.
    # PROMETHEUS_MULTIPROC_DIR has to be set before prometheus_client is
    # imported, so the import cannot live at module level.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
......@@ -155,10 +235,11 @@ async def tokenize(request: TokenizeRequest):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
......@@ -166,10 +247,11 @@ async def detokenize(request: DetokenizeRequest):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.get("/v1/models")
async def show_available_models():
......@@ -191,13 +273,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
......@@ -206,12 +286,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
......@@ -220,9 +299,31 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
# The profiler endpoints below are only registered on the router when
# envs.VLLM_TORCH_PROFILER_DIR is truthy at import time.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile():
        """Await the engine client's start_profile() and reply with 200."""
        logger.info("Starting profiler...")
        await async_engine_client.start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile():
        """Await the engine client's stop_profile() and reply with 200."""
        logger.info("Stopping profiler...")
        await async_engine_client.stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)
def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
......@@ -340,10 +441,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("args: %s", args)
async with build_async_engine_client(args) as async_engine_client:
# If None, creation of the client failed and we exit.
if async_engine_client is None:
return
app = await init_app(async_engine_client, args)
shutdown_task = await serve_http(
app,
engine=async_engine_client,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
......
......@@ -7,6 +7,7 @@ purposes.
import argparse
import json
import ssl
from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
......@@ -16,8 +17,19 @@ from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
lora_list: List[LoRAModulePath] = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
......@@ -26,8 +38,19 @@ class LoRAParserAction(argparse.Action):
class PromptAdapterParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
adapter_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
adapter_list: List[PromptAdapterPath] = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
......
......@@ -2,9 +2,9 @@ from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
import torch
from transformers import PreTrainedTokenizer
from vllm.sampling_params import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
class AllowedTokenIdsLogitsProcessor:
......@@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(
    logit_bias: Dict[int, float],
    token_ids: List[int],
    logits: torch.Tensor,
) -> torch.Tensor:
    """Add per-token biases to ``logits`` in place and return it.

    Args:
        logit_bias: Mapping from token id (used directly as an index into
            ``logits``) to the bias added at that position.
        token_ids: Unused by this function; kept in the signature.
            NOTE(review): presumably required by the LogitsProcessor
            calling convention -- confirm.
        logits: Tensor that is modified in place.

    Returns:
        The same ``logits`` object that was passed in.
    """
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits
......@@ -51,8 +53,9 @@ def logit_bias_logits_processor(logit_bias: Dict[str,
def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
logits_processors = []
tokenizer: AnyTokenizer,
) -> List[LogitsProcessor]:
logits_processors: List[LogitsProcessor] = []
if logit_bias:
try:
# Convert token_id to integer
......@@ -69,7 +72,7 @@ def get_logits_processors(
# Check if token_id is within the vocab size
for token_id, bias in clamped_logit_bias.items():
if token_id < 0 or token_id >= tokenizer.vocab_size:
raise ValueError("token_id in logit_bias contains "
raise ValueError(f"token_id {token_id} in logit_bias contains "
"out-of-vocab token id")
logits_processors.append(
......
......@@ -6,18 +6,20 @@ from typing import Any, Dict, List, Literal, Optional, Union
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]
try:
from sphinx.ext.autodoc.mock import _MockModule
......@@ -152,6 +154,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
prompt_logprobs: Optional[int] = None
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
......@@ -190,8 +193,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description=(
"A Jinja template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead."),
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
......@@ -232,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
# We now allow logprobs being true without top_logprobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
......@@ -248,7 +256,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
return SamplingParams(
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
......@@ -262,7 +270,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
prompt_logprobs=prompt_logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
......@@ -276,14 +284,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
    """Reject requests that set ``stream_options`` without streaming.

    Runs before field validation, so ``data`` is the raw mapping and is
    read with ``.get``.

    Raises:
        ValueError: If ``stream_options`` is truthy while ``stream`` is
            falsy or absent.
    """
    if data.get("stream_options") and not data.get("stream"):
        raise ValueError(
            "Stream options can only be defined when `stream=True`.")

    return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
    """Validate ``prompt_logprobs`` and ``top_logprobs`` on the raw input.

    ``prompt_logprobs`` may not be greater than zero together with
    ``stream`` and may not be negative. ``top_logprobs`` may not be
    negative and requires a truthy ``logprobs``.

    Raises:
        ValueError: On any of the violations above.
    """
    prompt_logprobs = data.get("prompt_logprobs")
    if prompt_logprobs is not None:
        if data.get("stream") and prompt_logprobs > 0:
            raise ValueError(
                "`prompt_logprobs` are not available when `stream=True`.")

        if prompt_logprobs < 0:
            raise ValueError("`prompt_logprobs` must be a positive value.")

    top_logprobs = data.get("top_logprobs")
    if top_logprobs is None:
        # Nothing further to check when top_logprobs is absent.
        return data

    if top_logprobs < 0:
        raise ValueError("`top_logprobs` must be a positive value.")

    if not data.get("logprobs"):
        raise ValueError(
            "when using `top_logprobs`, `logprobs` must be set to true."
        )

    return data
@model_validator(mode="before")
@classmethod
......@@ -316,19 +346,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif data["top_logprobs"] < 0:
raise ValueError(
"`top_logprobs` must be a value a positive value.")
return data
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
......@@ -367,6 +384,7 @@ class CompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[int] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
......@@ -417,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = get_logits_processors(
......@@ -434,7 +456,7 @@ class CompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
return SamplingParams(
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
......@@ -453,7 +475,7 @@ class CompletionRequest(OpenAIBaseModel):
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
......@@ -479,9 +501,17 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
    """Validate ``prompt_logprobs`` and ``logprobs`` on the raw input.

    ``prompt_logprobs`` may not be greater than zero together with
    ``stream`` and may not be negative; ``logprobs`` may not be negative.

    Raises:
        ValueError: On any of the violations above.
    """
    if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
        if data.get("stream") and prompt_logprobs > 0:
            raise ValueError(
                "`prompt_logprobs` are not available when `stream=True`.")

        if prompt_logprobs < 0:
            raise ValueError("`prompt_logprobs` must be a positive value.")

    if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
        raise ValueError("`logprobs` must be a positive value.")

    return data
@model_validator(mode="before")
......@@ -489,7 +519,8 @@ class CompletionRequest(OpenAIBaseModel):
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when stream is true.")
"Stream options can only be defined when `stream=True`.")
return data
......@@ -498,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/embeddings
model: str
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
......@@ -531,6 +562,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
class CompletionResponse(OpenAIBaseModel):
......@@ -626,6 +658,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
class DeltaMessage(OpenAIBaseModel):
......@@ -671,7 +704,7 @@ class BatchRequestInput(OpenAIBaseModel):
url: str
# The parameters of the request.
body: ChatCompletionRequest
body: Union[ChatCompletionRequest, EmbeddingRequest]
class BatchResponseData(OpenAIBaseModel):
......@@ -682,7 +715,7 @@ class BatchResponseData(OpenAIBaseModel):
request_id: str
# The body of the response.
body: Optional[ChatCompletionResponse] = None
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None
class BatchRequestOutput(OpenAIBaseModel):
......
......@@ -7,8 +7,14 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR = "SUCCESS"
VLLM_RPC_HEALTHY_STR = "HEALTHY"
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM = 0
@dataclass
......@@ -34,8 +40,10 @@ class RPCUtilityRequest(Enum):
GET_SCHEDULER_CONFIG = 5
GET_LORA_CONFIG = 6
DO_LOG_STATS = 7
CHECK_HEALTH = 8
IS_SERVER_HEALTHY = 8
IS_TRACING_ENABLED = 9
START_PROFILE = 10
STOP_PROFILE = 11
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
......
from contextlib import contextmanager
from typing import Any, AsyncIterator, Mapping, Optional
import asyncio
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Mapping, Optional
from uuid import uuid4
import cloudpickle
import zmq
......@@ -7,29 +9,152 @@ import zmq.asyncio
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf: disable
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
logger = init_logger(__name__)
# Path used for inprocess proxy.
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
class RPCClientClosedError(Exception):
    """Raised when the RPC client is used after it has been closed.

    Closing the client destroys its ZMQ context, which normally happens on
    server shutdown. Methods such as abort and do_log_stats may still be
    called afterwards and would then try to open a socket on the dead
    context, which causes a ZMQError and a huge stack trace.

    A dedicated exception type is raised instead so that those callers can
    suppress it.
    """
class AsyncEngineRPCClient:
def __init__(self, port: int):
class AsyncEngineRPCClient:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
The overall design mirrors the Asynchronous Client Server Pattern
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
On startup, the RPCClient:
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
via ipc, which uses unix sockets under the hood
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
- makes ROUTER socket (from_api_server) that binds to a random
inproc address, which uses memory under the hood
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
- runs a proxy in a background asyncio task between
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER, ipc)
Each request handled by the asyncio api_server calls generate():
- make a DEALER socket that connects to from_api_server via inproc
- send a RPCGenerateRequest to the inproc socket
- background proxy forwards the request from inproc -> ipc
- RPCServer responds to the request one token at a time over ipc
- background proxy forwards the response from ipc -> inproc
The connection looks like this:
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
Message routing is performed via identities that are managed by the
ROUTER socket. ROUTER sockets track every connection it has and
tells the caller about these. The way it tells the caller is to stick
the connection identity in front of each message received. When we
send the message via a ROUTER, we first send an identity frame.
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
for more details on connection identities.
This proxy design enables us to use a single unix socket, which
improves performance by avoiding syscalls (~5%) and avoids resource limits
such as ulimit, which defaults to 1024 on ubuntu.
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
which is required to avoid dropping messages under high load.
This is generally not advisable. However, since we are in control
of both sides of the connection + failure on either side is
catastrophic to the overall system health and memory profiling
suggests limited memory overhead relative to asyncio, we will
proceed for now.
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
for more details on high water marks.
"""
def __init__(self, rpc_path: str):
    """Create the ZMQ context, both proxy sockets and the proxy task.

    Args:
        rpc_path: ZMQ endpoint the DEALER socket facing the RPCServer
            binds to.

    Raises:
        ValueError: If the context's SOCKET_LIMIT is below
            VLLM_RPC_SOCKET_LIMIT_CUTOFF.
    """
    self.context = zmq.asyncio.Context()
    self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
    self._errored = False

    # Maximum number of sockets that can be opened (typically 65536).
    # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
    socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
    if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
        raise ValueError(
            f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
            "the number of concurrent requests vLLM can process. Launch "
            "vLLM with --disable-frontend-multiprocessing and open a "
            "GitHub issue so we can investigate.")

    # We only have 1 ipc connection that uses unix sockets, so it is
    # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. we will
    # not run into ulimit issues).
    self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)

    # IPC connection to RPC Server (uses unix sockets).
    self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
    self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
    self.to_rpc_server.bind(rpc_path)

    # In process proxy to RPC Server (uses memory-based messaging).
    self.from_api_server = self.context.socket(zmq.constants.ROUTER)
    self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
    self.from_api_server.bind(INPROC_PROXY_PATH)

    # Asyncio background task for the proxy. asyncio.create_task needs a
    # running event loop, so the client must be constructed from within
    # one.
    self.proxy_task = asyncio.create_task(
        self.run_proxy(self.from_api_server, self.to_rpc_server))

    # Since we open 1 inproc socket per request, we have a hard cap on
    # the number of requests that can run in vLLM with frontend
    # multiprocessing. This value is used by uvicorn to launch
    # with --limit-concurrency to return 503 when server is overloaded.
    # Each request may need 2 sockets:
    # 1 for generate(), 1 for abort(), do_log_stats(), check_health().
    # NOTE(review): the trailing "- 2" presumably reserves the two proxy
    # sockets created above -- confirm.
    self.limit_concurrency = socket_limit // 2 - 2
async def run_proxy(self, socket_from, socket_to):
    """Forward two-frame messages between the two sockets, forever.

    Both sockets are polled for incoming data. Whatever arrives on one is
    re-sent as ``[identity, msg]`` on the other. The loop has no exit
    condition of its own.
    """
    poller = zmq.asyncio.Poller()
    for sock in (socket_from, socket_to):
        poller.register(sock, zmq.constants.POLLIN)

    # (source, destination) pairs, serviced in this order on each wake-up.
    routes = ((socket_from, socket_to), (socket_to, socket_from))

    while True:
        ready = dict(await poller.poll())
        for source, destination in routes:
            if source in ready:
                identity, msg = await source.recv_multipart()
                await destination.send_multipart([identity, msg])
async def setup(self):
"""Setup the client before it starts sending server requests."""
# Wait until server is ready.
await self.wait_for_server()
await self._wait_for_server_rpc()
# Get the configs.
self.model_config = await self._get_model_config_rpc()
......@@ -47,59 +172,100 @@ class AsyncEngineRPCClient:
def close(self):
"""Destroy the ZeroMQ Context."""
# Close all sockets associated with this context and
# then terminate the context.
self.from_api_server.close()
self.to_rpc_server.close()
self.context.destroy()
@contextmanager
def to_proxy_socket(self):
    """Yield a DEALER socket connected to the in-process proxy.

    The socket is always closed (with ``linger=0``) when the ``with``
    block exits.

    Raises:
        RPCClientClosedError: If the ZMQ context is already closed.
    """
    # Connect to the RPCServer via the proxy.

    # Raise a sensible error if the client was already closed.
    # This can happen if a server shutdown is triggered but some coroutines
    # are still running requests.
    # There should not be a race condition with this check because we don't
    # yield to the event loop between here and opening the socket.
    if self.context.closed:
        raise RPCClientClosedError("The ZMQ client has already shut down")

    # Note that we use DEALER to enable asynchronous communication
    # to enable streaming.
    socket = self.context.socket(zmq.constants.DEALER)
    socket.set_hwm(VLLM_RPC_ZMQ_HWM)
    try:
        socket.connect(INPROC_PROXY_PATH)
        yield socket
    finally:
        socket.close(linger=0)
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                     expected_type: Any,
                                     error_message: str) -> Any:
    """Send an RPC request that is expecting data back.

    Args:
        request: Utility request, pickled with cloudpickle and sent as a
            single-frame multipart message.
        expected_type: Type the unpickled reply must be an instance of.
            ``None`` is additionally accepted when this is ``LoRAConfig``.
        error_message: Logged when the server replies with an exception,
            and used as the ``ValueError`` text on a type mismatch.

    Returns:
        The unpickled reply.

    Raises:
        TimeoutError: If no reply arrives within ``self._data_timeout`` ms.
        Exception: Whatever exception instance the server sent back.
        ValueError: If the reply has an unexpected type.
    """
    with self.to_proxy_socket() as socket:
        # Ping RPCServer with a request.
        await socket.send_multipart([cloudpickle.dumps(request)])

        # Make sure the server responds
        if await socket.poll(timeout=self._data_timeout) == 0:
            raise TimeoutError("Server didn't reply within "
                               f"{self._data_timeout} ms")

        # Await the data from the Server.
        # NOTE(review): cloudpickle.loads executes arbitrary code; this
        # relies on the peer being the trusted RPCServer process.
        data = cloudpickle.loads(await socket.recv())

    if isinstance(data, Exception):
        # Re-raise exceptions returned by the server, logging first so the
        # failure is attributed to this request (same handling as
        # _send_one_way_rpc_request).
        logger.error(error_message)
        raise data

    if not isinstance(data, expected_type):
        # LoRAConfig can be None.
        if not (expected_type == LoRAConfig and data is None):
            raise ValueError(error_message)

    return data
async def _send_one_way_rpc_request(
        self,
        request: RPC_REQUEST_TYPE,
        error_message: str,
        socket: Optional[zmq.asyncio.Socket] = None):
    """Send one-way RPC request to trigger an action.

    Args:
        request: Request object, pickled with cloudpickle.
        error_message: Logged when the server replies with an exception,
            and used as the ``ValueError`` text for any other
            non-success reply.
        socket: Existing socket to reuse. When None, a temporary proxy
            socket is opened for this call and closed afterwards.

    Returns:
        The success string acknowledged by the server.

    Raises:
        TimeoutError: If no reply arrives within ``self._data_timeout`` ms.
        Exception: Whatever exception instance the server sent back.
        ValueError: If the reply is anything other than the success string.
    """

    async def do_rpc_call(socket: zmq.asyncio.Socket,
                          request: RPC_REQUEST_TYPE):
        # One send / bounded wait / one receive round trip.
        await socket.send_multipart([cloudpickle.dumps(request)])

        if await socket.poll(timeout=self._data_timeout) == 0:
            raise TimeoutError("Server didn't reply within "
                               f"{self._data_timeout} ms")

        return cloudpickle.loads(await socket.recv())

    # Make a new socket connection.
    if socket is None:
        with self.to_proxy_socket() as socket:
            response = await do_rpc_call(socket, request)

    # Use existing socket connection.
    else:
        response = await do_rpc_call(socket, request)

    if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
        if isinstance(response, Exception):
            logger.error(error_message)
            raise response
        raise ValueError(error_message)

    return response
async def get_tokenizer(self, lora_request: LoRARequest):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
......@@ -112,12 +278,12 @@ class AsyncEngineRPCClient:
async def is_tracing_enabled(self) -> bool:
return self.tracing_flag
async def _wait_for_server_rpc(self):
    """Wait for the RPCServer to start up.

    Sends IS_SERVER_READY as a one-way request; any timeout or error raised
    by that request propagates to the caller.
    """
    await self._send_one_way_rpc_request(
        request=RPCUtilityRequest.IS_SERVER_READY,
        error_message="Unable to start RPC Server")
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
......@@ -151,7 +317,7 @@ class AsyncEngineRPCClient:
expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server")
async def _get_lora_config_rpc(self):
async def _get_lora_config_rpc(self) -> LoRAConfig:
"""Get LoRAConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
......@@ -159,29 +325,51 @@ class AsyncEngineRPCClient:
expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server")
async def _is_tracing_enabled_rpc(self) -> ParallelConfig:
async def _is_tracing_enabled_rpc(self) -> bool:
"""Get is_tracing_enabled flag from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.IS_TRACING_ENABLED,
expected_type=bool,
error_message="Could not get is_tracing_enabled flag from RPC "
"Server")
error_message="Could not get is_tracing_enabled from RPC Server")
async def abort(self, request_id: str):
    """Send an ABORT_REQUEST signal to the RPC Server.

    `RPCClientClosedError` and `TimeoutError` are both suppressed, so this
    call is best-effort and does not raise for either of them.

    Args:
        request_id: id of the request to abort.
    """
    # Why timeouts are ignored too: when the server is busy and a very large
    # volume of aborts arrives at once (observed when aborting 20k requests
    # while another 2k were processing), it may not ack all of them in time.
    # Many of them time out, yet the server was seen to abort every request
    # successfully. We therefore assume the server has received, or will
    # receive, the abort, which avoids a massive wall of `TimeoutError`
    # stack traces.
    with suppress(RPCClientClosedError, TimeoutError):
        abort_request = RPCAbortRequest(request_id)
        failure_text = f"RPCAbortRequest {request_id} failed"
        await self._send_one_way_rpc_request(request=abort_request,
                                             error_message=failure_text)
async def do_log_stats(self):
    """Send a DO_LOG_STATS signal to the RPC Server.

    An `RPCClientClosedError` raised while sending is suppressed.
    """
    failure_text = "RPCRequest DO_LOG_STATS failed."
    with suppress(RPCClientClosedError):
        stats_request = RPCUtilityRequest.DO_LOG_STATS
        await self._send_one_way_rpc_request(request=stats_request,
                                             error_message=failure_text)
@property
def is_running(self) -> bool:
    """True while the client's `_errored` flag is not set."""
    return False if self._errored else True
@property
def is_stopped(self) -> bool:
    """Return the client's `_errored` flag."""
    errored_flag = self._errored
    return errored_flag
@property
def errored(self) -> bool:
    """Return the client's `_errored` flag."""
    errored_flag = self._errored
    return errored_flag
async def generate(
self,
inputs: PromptInputs,
......@@ -190,11 +378,12 @@ class AsyncEngineRPCClient:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
) -> AsyncGenerator[RequestOutput, None]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
with self.socket() as socket:
finished = False
try:
with self.to_proxy_socket() as socket:
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
......@@ -208,41 +397,57 @@ class AsyncEngineRPCClient:
])
# Stream back the results from the RPC Server.
while True:
while not finished:
message = await socket.recv()
request_output = cloudpickle.loads(message)
if isinstance(request_output, Exception):
# On exception, check if the server is still healthy
# possibly setting the `errored` property.
if not self._errored:
try:
await self.check_health(socket=socket)
except Exception as e:
self._errored = True
logger.exception(repr(e))
# NB: do before raising here so that the flag is set
# by the time the caller receives this exception
raise request_output
if request_output.finished:
break
finished = request_output.finished
yield request_output
yield request_output
finally:
# Request was canceled by the client.
if not finished and not self._errored:
await self.abort(request_id)
async def check_health(self) -> None:
async def check_health(self,
socket: Optional[zmq.asyncio.Socket] = None
) -> None:
"""Raise if unhealthy"""
with self.socket() as socket:
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
error_message="Got Unhealthy response from RPC Server",
socket=socket)
# Ping RPCServer with CHECK_HEALTH request.
await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
)
async def encode(self, *args,
                 **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
    """Always raise: this client does not implement embeddings.

    Raises:
        NotImplementedError: unconditionally; all arguments are ignored.
    """
    reason = "Embeddings not supported with multiprocessing backend"
    raise NotImplementedError(reason)
# Await the reply from the server.
# TODO: do we need an internal timeout here?
# Or do we expect the external probe to timeout and let this chill?
health_message = cloudpickle.loads(await socket.recv())
async def start_profile(self) -> None:
"""Start profiling the engine"""
if isinstance(health_message, Exception):
raise health_message
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.START_PROFILE,
error_message="RPCRequest START_PROFILE failed.")
if health_message != VLLM_RPC_HEALTHY_STR:
raise ValueError("Expected healthy response from backend but got "
"f{health_message}")
async def stop_profile(self) -> None:
"""Stop profiling the engine"""
async def encode(self, *args,
**kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.STOP_PROFILE,
error_message="RPCRequest STOP_PROFILE failed.")
\ No newline at end of file
import asyncio
import signal
from typing import Any, Coroutine
from typing import Any, Coroutine, Union
import cloudpickle
import uvloop
import zmq
import zmq.asyncio
from typing_extensions import Never
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]
class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int):
usage_context: UsageContext, rpc_path: str):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context)
self.engine = AsyncLLMEngine.from_engine_args(
async_engine_args, usage_context=usage_context)
# Initialize context.
self.context = zmq.asyncio.Context()
# Init socket for readiness state.
self.socket = self.context.socket(zmq.constants.ROUTER)
# Note numeric form of localhost should be used for zmq bind(),
# see https://stackoverflow.com/a/8958414
self.socket.bind(f"tcp://127.0.0.1:{port}")
# Init socket.
self.socket = self.context.socket(zmq.constants.DEALER)
self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
self.socket.connect(rpc_path)
def cleanup(self):
    """Release the server's socket, zmq context and engine, in that order."""
    # Close the socket, then destroy the zmq context it was created from.
    zmq_socket = self.socket
    zmq_socket.close()
    zmq_context = self.context
    zmq_context.destroy()

    self.engine.shutdown_background_loop()
    # Drop the engine attribute so that the engine can be GC'ed.
    delattr(self, "engine")
async def get_model_config(self, identity):
"""Send the ModelConfig"""
model_config = await self.engine.get_model_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(model_config)])
async def get_decoding_config(self, identity):
"""Send the DecodingConfig"""
decoding_config = await self.engine.get_decoding_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(decoding_config)])
async def get_lora_config(self, identity):
lora_config = await self.engine.get_lora_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(lora_config)])
async def get_scheduler_config(self, identity):
"""Send the SchedulerConfig"""
parallel_config = await self.engine.get_scheduler_config()
async def get_config(self, identity, request):
try:
config: CONFIG_TYPE
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
config = await self.engine.get_model_config()
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
config = await self.engine.get_decoding_config()
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
config = await self.engine.get_lora_config()
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
config = await self.engine.get_scheduler_config()
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
config = await self.engine.get_parallel_config()
else:
raise ValueError("Unknown Config Request: %s", request)
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def get_parallel_config(self, identity):
"""Send the ParallelConfig"""
parallel_config = await self.engine.get_parallel_config()
[identity, cloudpickle.dumps(config)])
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
......@@ -84,28 +80,23 @@ class AsyncEngineRPCServer:
"""Log stats and confirm success."""
await self.engine.do_log_stats()
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
try:
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)
# Send confirmation to the client.
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
except Exception as e:
result = e
await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
......@@ -122,17 +113,37 @@ class AsyncEngineRPCServer:
[identity, cloudpickle.dumps(request_output)])
except Exception as e:
### Notify client of all failures
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def start_profile(self, identity):
    """Start the engine profiler and report the outcome to the client.

    Replies to `identity` with the pickled VLLM_RPC_SUCCESS_STR on success.
    If starting the profiler raises, the pickled exception is sent instead
    (the same pattern the `abort` handler uses), so the client always gets
    a reply rather than only noticing the failure through a timeout.

    Args:
        identity: routing frame placed first in the multipart reply.
    """
    try:
        logger.info("Starting profiler...")
        await self.engine.start_profile()
        logger.info("Profiler started.")
        result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
    except Exception as e:
        # Hand the failure to the client instead of dropping the reply.
        result = e
    await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
async def stop_profile(self, identity):
    """Stop the engine profiler and report the outcome to the client.

    Replies to `identity` with the pickled VLLM_RPC_SUCCESS_STR on success.
    If stopping the profiler raises, the pickled exception is sent instead
    (the same pattern the `abort` handler uses), so the client always gets
    a reply rather than only noticing the failure through a timeout.

    Args:
        identity: routing frame placed first in the multipart reply.
    """
    try:
        logger.info("Stopping profiler...")
        await self.engine.stop_profile()
        logger.info("Profiler stopped.")
        result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
    except Exception as e:
        # Hand the failure to the client instead of dropping the reply.
        result = e
    await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
......@@ -146,24 +157,26 @@ class AsyncEngineRPCServer:
return self.abort(identity, request)
elif isinstance(request, RPCUtilityRequest):
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
return self.get_model_config(identity)
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
return self.get_parallel_config(identity)
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
return self.get_decoding_config(identity)
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
return self.get_scheduler_config(identity)
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
return self.get_lora_config(identity)
if request in [
RPCUtilityRequest.GET_MODEL_CONFIG,
RPCUtilityRequest.GET_PARALLEL_CONFIG,
RPCUtilityRequest.GET_DECODING_CONFIG,
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
RPCUtilityRequest.GET_LORA_CONFIG
]:
return self.get_config(identity, request)
elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity)
elif request == RPCUtilityRequest.IS_SERVER_READY:
return self.is_server_ready(identity)
elif request == RPCUtilityRequest.CHECK_HEALTH:
elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
elif request == RPCUtilityRequest.START_PROFILE:
return self.start_profile(identity)
elif request == RPCUtilityRequest.STOP_PROFILE:
return self.stop_profile(identity)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
......@@ -213,6 +226,6 @@ async def run_server(server: AsyncEngineRPCServer):
def run_rpc_server(async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int):
server = AsyncEngineRPCServer(async_engine_args, usage_context, port)
asyncio.run(run_server(server))
usage_context: UsageContext, rpc_path: str):
server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
uvloop.run(run_server(server))
import asyncio
from io import StringIO
from typing import Awaitable, List
from typing import Awaitable, Callable, List
import aiohttp
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
ErrorResponse)
EmbeddingResponse, ErrorResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
......@@ -82,27 +85,26 @@ async def write_file(path_or_url: str, data: str) -> None:
f.write(data)
async def run_request(chat_serving: OpenAIServingChat,
async def run_request(serving_engine_func: Callable,
request: BatchRequestInput) -> BatchRequestOutput:
chat_request = request.body
chat_response = await chat_serving.create_chat_completion(chat_request)
response = await serving_engine_func(request.body)
if isinstance(chat_response, ChatCompletionResponse):
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
body=chat_response, request_id=f"vllm-batch-{random_uuid()}"),
body=response, request_id=f"vllm-batch-{random_uuid()}"),
error=None,
)
elif isinstance(chat_response, ErrorResponse):
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=chat_response.code,
status_code=response.code,
request_id=f"vllm-batch-{random_uuid()}"),
error=chat_response,
error=response,
)
else:
raise ValueError("Request must not be sent in stream mode")
......@@ -128,6 +130,7 @@ async def main(args):
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
# Create the openai serving objects.
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
......@@ -138,12 +141,35 @@ async def main(args):
request_logger=request_logger,
chat_template=None,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
served_model_names,
request_logger=request_logger,
)
# Submit all requests in the file to the engine "concurrently".
response_futures: List[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
response_futures.append(run_request(openai_serving_chat, request))
# Determine the type of request and run it: pick the serving coroutine
# function from the request URL and queue the resulting awaitable.
if request.url == "/v1/chat/completions":
    response_futures.append(
        run_request(openai_serving_chat.create_chat_completion, request))
elif request.url == "/v1/embeddings":
    response_futures.append(
        run_request(openai_serving_embedding.create_embedding, request))
else:
    # Adjacent string literals are joined with no separator, so the first
    # fragment must end with a space to keep "are supported" readable.
    raise ValueError("Only /v1/chat/completions and /v1/embeddings are "
                     "supported in the batch endpoint.")
responses = await asyncio.gather(*response_futures)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment