Commit afd0da21 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.1' into v0.7.1-dev

parents 1a11f127 4f4d427a
......@@ -25,7 +25,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
......@@ -240,22 +243,34 @@ class MQLLMEngineClient(EngineClient):
queue = self.output_queues.get(request_id)
if queue is not None:
queue.put_nowait(exception)
# Put each output into the appropriate queue.
elif isinstance(request_outputs, RPCAdapterLoadedResponse):
self._add_output(request_outputs)
else:
# Put each output into the appropriate steam.
for request_output in request_outputs:
queue = self.output_queues.get(
request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
self._add_output(request_output)
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient output handler.")
def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse]):
queue = self.output_queues.get(request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
async def setup(self):
"""Setup the client before it starts sending server requests."""
# Start output_loop
self.output_loop = asyncio.create_task(self.run_output_handler_loop())
if self.output_loop is None:
# only generate once to avoid multiple concurrent output_loops
# this will lead to race conditions and wrong orders of tokens
# returned by the engine
# setup will be called multiple times during the startup of
# the engine
self.output_loop = asyncio.create_task(
self.run_output_handler_loop())
with self.get_data_socket() as socket:
# Wait until server is ready.
......@@ -264,6 +279,7 @@ class MQLLMEngineClient(EngineClient):
self.tracing_flag = response.tracing_enabled
# Start health_loop.
if self.health_loop is None:
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
......@@ -659,3 +675,31 @@ class MQLLMEngineClient(EngineClient):
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
async def reset_prefix_cache(self) -> None:
"""Reset the prefix cache"""
await self._send_one_way_rpc_request(
request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
socket=self.input_socket)
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request = RPCLoadAdapterRequest(lora_request)
# Create output queue for this requests.
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
self.output_queues[request.request_id] = queue
# Send the request
request_bytes = pickle.dumps(request)
await self.input_socket.send_multipart((request_bytes, ), copy=False)
# Wait for the response
request_output = await queue.get()
self.output_queues.pop(request.request_id)
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
......@@ -14,11 +14,13 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
......@@ -234,6 +236,10 @@ class MQLLMEngine:
self.start_profile()
else:
self.stop_profile()
elif isinstance(request, RPCLoadAdapterRequest):
self._handle_load_adapter_request(request)
elif isinstance(request, RPCResetPrefixCacheRequest):
self.reset_prefix_cache()
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
......@@ -284,6 +290,20 @@ class MQLLMEngine:
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)
def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
try:
self.engine.add_lora(request.lora_request)
except BaseException as e:
# Send back an error if the adater fails to load
rpc_err = RPCError(request_id=request.request_id,
is_engine_errored=False,
exception=e)
self._send_outputs(rpc_err)
return
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id))
def _health_check(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
......@@ -296,7 +316,11 @@ class MQLLMEngine:
self._send_unhealthy(e)
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
"""Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if outputs:
try:
from ray.exceptions import RayTaskError
......@@ -335,16 +359,13 @@ class MQLLMEngine:
self._errored_with = e
def start_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
self.engine.start_profile()
def stop_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")
self.engine.stop_profile()
def reset_prefix_cache(self) -> bool:
return self.engine.reset_prefix_cache()
def signal_handler(*_) -> None:
......
......@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@staticmethod
@functools.lru_cache
def _log_prompt_logprob_unsupported_warning_once():
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
logger.warning(
"Prompt logprob is not supported by multi step workers. "
......@@ -144,7 +144,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def _process_decode_and_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
new_char_count = 0
if sampling_params.detokenize:
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
......
......@@ -102,9 +102,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
Args:
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
outputs: the :class:`SequenceGroupOutput` for a single scheduler step
"""
assert len(outputs) == 1, ("Single step should only has 1 output.")
assert len(outputs) == 1, "Single step should only have 1 output."
output = outputs[0]
assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output)
......
......@@ -270,3 +270,13 @@ class EngineClient(ABC):
async def stop_profile(self) -> None:
"""Start profiling the engine"""
...
@abstractmethod
async def reset_prefix_cache(self) -> None:
"""Reset the prefix cache"""
...
@abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
...
......@@ -3,10 +3,10 @@ import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from functools import lru_cache, partial
from functools import cache, lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
Literal, Optional, Tuple, TypeVar, Union, cast)
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
......@@ -23,6 +23,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
......@@ -31,13 +33,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image,
async_get_and_parse_video,
get_and_parse_audio, get_and_parse_image,
get_and_parse_video)
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once
logger = init_logger(__name__)
......@@ -368,16 +365,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._items: List[_T] = []
self._items_by_modality = defaultdict[str, list[_T]](list)
@property
def model_config(self) -> ModelConfig:
return self._model_config
@property
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@staticmethod
@lru_cache(maxsize=None)
@cache
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
return tokenizer.decode(token_index)
......@@ -392,7 +392,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>"
if model_type == "minicpmv":
if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
"pixtral"):
......@@ -403,8 +403,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat", "NVLM_D",
"h2ovl_chat"):
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
"NVLM_D", "h2ovl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
......@@ -424,10 +424,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type == "qwen2_audio":
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
if model_type == "minicpmo":
return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"):
return "(<video>./</video>)"
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.video_token_index)
......@@ -435,38 +439,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else:
raise TypeError(f"Unknown modality: {modality}")
@staticmethod
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)
# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
"""
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1
current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")
self._consumed_items[modality] = current_count
self._items.append(item)
self._items_by_modality[modality].append(item)
return self._placeholder_str(modality, current_count)
......@@ -475,22 +460,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
return self._combine(self._items) if self._items else None
if self._items_by_modality:
return dict(self._items_by_modality)
return None
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
class AsyncMultiModalItemTracker(
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items:
items = await asyncio.gather(*self._items)
return self._combine(items)
if self._items_by_modality:
return {
modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}
return None
......@@ -522,7 +511,7 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError
@abstractmethod
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
def parse_input_audio(self, input_audio: InputAudio) -> None:
raise NotImplementedError
@abstractmethod
......@@ -537,31 +526,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url,
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)
image = self._connector.fetch_image(image_url)
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio = get_and_parse_audio(audio_url)
audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
input_audio_data = input_audio.get("data","")
input_audio_format = input_audio.get("format","")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio = get_and_parse_audio(audio_url)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
return self.parse_audio(audio_url)
def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url)
video = self._connector.fetch_video(video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
......@@ -573,33 +562,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super().__init__()
self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)
image_coro = self._connector.fetch_image_async(image_url)
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio_coro = async_get_and_parse_audio(audio_url)
audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
input_audio_data = input_audio.get("data","")
input_audio_format = input_audio.get("format","")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio_coro = async_get_and_parse_audio(audio_url)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
return self.parse_audio(audio_url)
def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url)
video = self._connector.fetch_video_async(video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
......@@ -695,10 +682,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str,
Callable[[ChatCompletionContentPartParam],
Union[str, Dict[str,str]]]] = {
MM_PARSER_MAP: Dict[
str,
Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text":
lambda part: _TextParser(part).get("text", ""),
"image_url":
......@@ -715,8 +705,7 @@ MM_PARSER_MAP: Dict[str,
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str,
Union[str, Dict[str, str]]]:
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
"""
Parses a given multi-modal content part based on its type.
......@@ -783,7 +772,7 @@ def _parse_chat_message_content_parts(
*,
wrap_dicts: bool,
) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = []
content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser()
......@@ -814,7 +803,7 @@ def _parse_chat_message_content_part(
mm_parser: BaseMultiModalContentParser,
*,
wrap_dicts: bool,
) -> Optional[Union[str, Dict[str, str]]]:
) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
......@@ -823,8 +812,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders.
"""
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
return text
return part
# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
......@@ -855,7 +843,7 @@ def _parse_chat_message_content_part(
return {'type': 'audio'} if wrap_dicts else None
if part_type == "input_audio":
dict_content = cast(Dict[str, str], content)
dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None
......@@ -1000,14 +988,14 @@ def apply_mistral_chat_template(
**kwargs: Any,
) -> List[int]:
if chat_template is not None:
print_warning_once(
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
print_warning_once(
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
print_warning_once(
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
......
import itertools
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload)
import cloudpickle
import torch
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import deprecated
from typing_extensions import TypeVar, deprecated
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
......@@ -21,7 +24,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
parse_chat_messages,
resolve_chat_template_content_format)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
......@@ -41,6 +44,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
......@@ -186,6 +191,13 @@ class LLM:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
if "worker_cls" in kwargs:
worker_cls = kwargs["worker_cls"]
# if the worker_cls is not qualified string name,
# we serialize it using cloudpickle to avoid pickling issues
if isinstance(worker_cls, type):
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
......@@ -225,18 +237,11 @@ class LLM:
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self.engine_class = self.get_engine_class()
# TODO(rob): enable mp by default (issue with fork vs spawn)
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
def __del__(self):
if hasattr(self, 'llm_engine') and self.llm_engine and hasattr(
self.llm_engine, "shutdown"):
self.llm_engine.shutdown()
@staticmethod
def get_engine_class() -> Type[LLMEngine]:
if envs.VLLM_USE_V1:
......@@ -462,9 +467,47 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
executor = self.llm_engine.model_executor
return executor.collective_rpc(method, timeout, args, kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
executor = self.llm_engine.model_executor
return executor.apply_model(func)
def beam_search(
self,
prompts: List[Union[str, List[int]]],
prompts: List[Union[TokensPrompt, TextPrompt]],
params: BeamSearchParams,
) -> List[BeamSearchOutput]:
"""
......@@ -500,8 +543,10 @@ class LLM:
instances: List[BeamSearchInstance] = []
for prompt in prompts:
prompt_tokens = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
if is_token_prompt(prompt):
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append(BeamSearchInstance(prompt_tokens))
for _ in range(max_tokens):
......@@ -952,6 +997,107 @@ class LLM:
return [ClassificationRequestOutput.from_base(item) for item in items]
def _embedding_score(
self,
tokenizer: AnyTokenizer,
text_1: List[Union[str, TextPrompt, TokensPrompt]],
text_2: List[Union[str, TextPrompt, TokensPrompt]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
encoded_output = self.encode(
text_1 + text_2,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
encoded_output_1 = encoded_output[0:len(text_1)]
encoded_output_2 = encoded_output[len(text_1):]
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
output_pairs = [(t1, t2)
for t1, t2 in zip(encoded_output_1, encoded_output_2)]
scores = []
scorer = torch.nn.CosineSimilarity(0)
for embed_1, embed_2 in output_pairs:
pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)
if (pad_token_id := getattr(tokenizer, "pad_token_id",
None)) is not None:
tokens = embed_1.prompt_token_ids + [
pad_token_id
] + embed_2.prompt_token_ids
else:
tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{embed_1.request_id}_{embed_2.request_id}",
outputs=pair_score,
prompt_token_ids=tokens,
finished=True))
items = self.engine_class.validate_outputs(scores,
PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
def _cross_encoding_score(
self,
tokenizer: Union[AnyTokenizer],
text_1: List[Union[str, TextPrompt, TokensPrompt]],
text_2: List[Union[str, TextPrompt, TokensPrompt]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"Score API is only enabled for `--task embed or score`")
if len(text_1) == 1:
text_1 = text_1 * len(text_2)
input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
pooling_params = PoolingParams()
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
parsed_prompts = []
for q, t in input_pairs:
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
parsed_prompts.append(engine_prompt)
self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
items = self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
def score(
self,
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
......@@ -1003,25 +1149,20 @@ class LLM:
raise ValueError(" ".join(messages))
if not self.llm_engine.model_config.is_cross_encoder:
raise ValueError("Your model does not support cross encoding")
if self.llm_engine.model_config.task != "score":
raise ValueError("Score API is only enabled for `--task score`")
tokenizer = self.llm_engine.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
if self.llm_engine.model_config.task not in ("embed", "score"):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
"Score API is only enabled for `--task embed or --task score`")
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
tokenizer = self.llm_engine.get_tokenizer()
def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict):
if "multi_modal_data" in prompt:
raise ValueError("Multi-modal prompt is not "
"supported for cross encoding")
"supported for scoring")
elif "prompt_token_ids" in prompt:
prompt = tokenizer.decode(
cast(TokensPrompt, prompt)["prompt_token_ids"])
......@@ -1047,40 +1188,15 @@ class LLM:
if len(text_2) == 0:
raise ValueError("At least one text_pair element must be given")
if len(text_1) == 1:
text_1 = text_1 * len(text_2)
input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
pooling_params = PoolingParams()
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
parsed_prompts = []
for q, t in input_pairs:
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
parsed_prompts.append(engine_prompt)
self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
items = self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
if self.llm_engine.model_config.is_cross_encoder:
return self._cross_encoding_score(tokenizer, text_1, text_2,
truncate_prompt_tokens, use_tqdm,
lora_request,
prompt_adapter_request)
else:
return self._embedding_score(tokenizer, text_1, text_2,
truncate_prompt_tokens, use_tqdm,
lora_request, prompt_adapter_request)
def start_profile(self) -> None:
self.llm_engine.start_profile()
......@@ -1088,6 +1204,36 @@ class LLM:
def stop_profile(self) -> None:
self.llm_engine.stop_profile()
def reset_prefix_cache(self) -> bool:
return self.llm_engine.reset_prefix_cache()
def sleep(self, level: int = 1):
"""
Put the engine to sleep. The engine should not process any requests.
The caller should guarantee that no requests are being processed
during the sleep period, before `wake_up` is called.
:param level: The sleep level. Level 1 sleep will offload the model
weights and discard the kv cache. The content of kv cache is
forgotten. Level 1 sleep is good for sleeping and waking up the
engine to run the same model again. The model weights are backed
up in CPU memory. Please make sure there's enough CPU memory to
store the model weights. Level 2 sleep will discard both the model
weights and the kv cache. The content of both the model weights
and kv cache is forgotten. Level 2 sleep is good for sleeping and
waking up the engine to run a different model or update the model,
where previous model weights are not needed. It reduces CPU memory
pressure.
"""
self.reset_prefix_cache()
self.llm_engine.sleep(level=level)
def wake_up(self):
"""
Wake up the engine from sleep mode. See the :meth:`sleep` method
for more details."""
self.llm_engine.wake_up()
# LEGACY
def _convert_v1_inputs(
self,
......
import asyncio
import atexit
import gc
import importlib
import inspect
import multiprocessing
......@@ -7,16 +8,17 @@ import os
import re
import signal
import socket
import sys
import tempfile
import uuid
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set, Tuple
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
......@@ -44,22 +46,31 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
RerankRequest, RerankResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
......@@ -97,6 +108,11 @@ async def lifespan(app: FastAPI):
task.add_done_callback(_running_tasks.remove)
else:
task = None
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
gc.collect()
gc.freeze()
try:
yield
finally:
......@@ -133,32 +149,21 @@ async def build_async_engine_client_from_engine_args(
Returns the Client or None if the creation failed.
"""
# Fall back
# TODO: fill out feature matrix.
# AsyncLLMEngine.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config(
UsageContext.OPENAI_API_SERVER)
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)
build_engine = partial(AsyncLLMEngine.from_engine_args,
engine_client: Optional[EngineClient] = None
try:
engine_client = AsyncLLMEngine.from_engine_args(
engine_args=engine_args,
engine_config=engine_config,
usage_context=UsageContext.OPENAI_API_SERVER)
if uses_ray:
# Must run in main thread with ray for its signal handlers to work
engine_client = build_engine()
else:
engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_engine)
yield engine_client
if hasattr(engine_client, "shutdown"):
finally:
if engine_client and hasattr(engine_client, "shutdown"):
engine_client.shutdown()
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
# MQLLMEngine.
else:
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
......@@ -280,6 +285,10 @@ def base(request: Request) -> OpenAIServing:
return tokenization(request)
def models(request: Request) -> OpenAIServingModels:
return request.app.state.openai_serving_models
def chat(request: Request) -> Optional[OpenAIServingChat]:
return request.app.state.openai_serving_chat
......@@ -300,6 +309,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
return request.app.state.openai_serving_scores
def rerank(request: Request) -> Optional[JinaAIServingRerank]:
return request.app.state.jinaai_serving_reranking
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
......@@ -315,6 +328,12 @@ async def health(raw_request: Request) -> Response:
return Response(status_code=200)
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
"""Ping check. Endpoint required for SageMaker"""
return await health(raw_request)
@router.post("/tokenize")
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
......@@ -347,10 +366,10 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
handler = base(raw_request)
handler = models(raw_request)
models = await handler.show_available_models()
return JSONResponse(content=models.model_dump())
models_ = await handler.show_available_models()
return JSONResponse(content=models_.model_dump())
@router.get("/version")
......@@ -414,6 +433,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
generator: Union[ErrorResponse, EmbeddingResponse]
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
......@@ -488,6 +509,103 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
@router.post("/rerank")
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
handler = rerank(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Rerank (Score) API")
generator = await handler.do_rerank(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, RerankResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/v1/rerank")
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
logger.warning_once(
"To indicate that the rerank API is not part of the standard OpenAI"
" API, we have located it at `/rerank`. Please update your client"
"accordingly. (Note: Conforms to JinaAI rerank API)")
return await do_rerank(request, raw_request)
@router.post("/v2/rerank")
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
"default": (CompletionRequest, create_completion),
},
"embed": {
"messages": (EmbeddingChatRequest, create_embedding),
"default": (EmbeddingCompletionRequest, create_embedding),
},
"score": {
"default": (RerankRequest, do_rerank)
},
"rerank": {
"default": (RerankRequest, do_rerank)
},
"reward": {
"messages": (PoolingChatRequest, create_pooling),
"default": (PoolingCompletionRequest, create_pooling),
},
"classify": {
"messages": (PoolingChatRequest, create_pooling),
"default": (PoolingCompletionRequest, create_pooling),
},
}
if envs.VLLM_SERVER_DEV_MODE:
@router.post("/reset_prefix_cache")
async def reset_prefix_cache(raw_request: Request):
"""
Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server.
"""
logger.info("Resetting prefix cache...")
await engine_client(raw_request).reset_prefix_cache()
return Response(status_code=200)
@router.post("/invocations")
async def invocations(raw_request: Request):
"""
For SageMaker, routes requests to other handlers based on model `task`.
"""
body = await raw_request.json()
task = raw_request.app.state.task
if task not in TASK_HANDLERS:
raise HTTPException(
status_code=400,
detail=f"Unsupported task: '{task}' for '/invocations'. "
f"Expected one of {set(TASK_HANDLERS.keys())}")
handler_config = TASK_HANDLERS[task]
if "messages" in body:
request_model, handler = handler_config["messages"]
else:
request_model, handler = handler_config["default"]
# this is required since we lose the FastAPI automatic casting
request = request_model.model_validate(body)
return await handler(request, raw_request)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
......@@ -516,9 +634,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
handler = models(raw_request)
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
......@@ -529,9 +645,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
handler = models(raw_request)
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
......@@ -602,7 +716,7 @@ def build_app(args: Namespace) -> FastAPI:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
app.add_middleware(imported) # type: ignore[arg-type]
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
......@@ -612,7 +726,7 @@ def build_app(args: Namespace) -> FastAPI:
return app
def init_app_state(
async def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
......@@ -639,34 +753,40 @@ def init_app_state(
resolved_chat_template = load_chat_template(args.chat_template)
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
)
await state.openai_serving_models.init_static_loras()
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
enable_reasoning=args.enable_reasoning,
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.runner_type == "generate" else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
......@@ -674,7 +794,7 @@ def init_app_state(
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
......@@ -682,18 +802,24 @@ def init_app_state(
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
state.jinaai_serving_reranking = JinaAIServingRerank(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)
state.task = model_config.task
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
......@@ -715,11 +841,18 @@ async def run_server(args, **uvicorn_kwargs) -> None:
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
valide_tool_parses = ToolParserManager.tool_parsers.keys()
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valide_tool_parses:
and args.tool_call_parser not in valid_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valide_tool_parses)} }})")
f"(chose from {{ {','.join(valid_tool_parses)} }})")
valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
if args.enable_reasoning \
and args.reasoning_parser not in valid_reasoning_parses:
raise KeyError(
f"invalid reasoning parser: {args.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
......@@ -741,7 +874,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
app = build_app(args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
await init_app_state(engine_client, model_config, app.state, args)
shutdown_task = await serve_http(
app,
......@@ -753,6 +886,8 @@ async def run_server(args, **uvicorn_kwargs) -> None:
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
# Workaround to work on macOS
fd=sock.fileno() if sys.platform.startswith("darwin") else None,
**uvicorn_kwargs,
)
......
......@@ -12,7 +12,8 @@ from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser
......@@ -79,29 +80,29 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
default=None,
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
help="Host name.")
parser.add_argument("--port", type=int, default=8000, help="Port number.")
parser.add_argument(
"--uvicorn-log-level",
type=str,
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="log level for uvicorn")
help="Log level for uvicorn.")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
help="Allow credentials.")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
help="Allowed origins.")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
help="Allowed methods.")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
help="Allowed headers.")
parser.add_argument("--api-key",
type=nullable_str,
default=None,
......@@ -115,10 +116,10 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action=LoRAParserAction,
help="LoRA module configurations in either 'name=path' format"
"or JSON format. "
"Example (old format): 'name=path' "
"Example (old format): ``'name=path'`` "
"Example (new format): "
"'{\"name\": \"name\", \"local_path\": \"path\", "
"\"base_model_name\": \"id\"}'")
"``{\"name\": \"name\", \"path\": \"lora_path\", "
"\"base_model_name\": \"id\"}``")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
......@@ -132,7 +133,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
"for the specified model.")
parser.add_argument(
'--chat-template-content-format',
type=str,
......@@ -141,38 +142,39 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help='The format to render message content within a chat template.'
'\n\n'
'* "string" will render the content as a string. '
'Example: "Hello World"\n'
'Example: ``"Hello World"``\n'
'* "openai" will render the content as a list of dictionaries, '
'similar to OpenAI schema. '
'Example: [{"type": "text", "text": "Hello world!"}]')
'Example: ``[{"type": "text", "text": "Hello world!"}]``')
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
"``request.add_generation_prompt=true``.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
default=None,
help="The file path to the SSL key file")
help="The file path to the SSL key file.")
parser.add_argument("--ssl-certfile",
type=nullable_str,
default=None,
help="The file path to the SSL cert file")
help="The file path to the SSL cert file.")
parser.add_argument("--ssl-ca-certs",
type=nullable_str,
default=None,
help="The CA certificates file")
help="The CA certificates file.")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
help="Whether client certificate is required (see stdlib ssl module's)."
)
parser.add_argument(
"--root-path",
type=nullable_str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
help="FastAPI root_path when app is behind a path based routing proxy."
)
parser.add_argument(
"--middleware",
type=nullable_str,
......@@ -182,15 +184,15 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"using ``@app.middleware('http')``. "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
"using ``app.add_middleware()``. ")
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When --max-logprobs is specified, represents single tokens as "
"strings of the form 'token_id:{token_id}' so that tokens that "
"are not JSON-encodable can be identified.")
help="When ``--max-logprobs`` is specified, represents single tokens "
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified.")
parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
......@@ -205,9 +207,25 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--enable-auto-tool-choice",
action="store_true",
default=False,
help="Enable auto tool choice for supported models. Use "
"``--tool-call-parser`` to specify which parser to use.")
parser.add_argument(
"--enable-reasoning",
action="store_true",
default=False,
help="Whether to enable reasoning_content for the model. "
"If enabled, the model will be able to generate reasoning content.")
valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
parser.add_argument(
"--reasoning-parser",
type=str,
metavar="{" + ",".join(valid_reasoning_parsers) + "}",
default=None,
help=
"Enable auto tool choice for supported models. Use --tool-call-parser"
" to specify which parser to use")
"Select the reasoning parser depending on the model that you're using."
" This is used to parse the reasoning content into OpenAI API "
"format. Required for ``--enable-reasoning``.")
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
......@@ -219,7 +237,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")
"format. Required for ``--enable-auto-tool-choice``.")
parser.add_argument(
"--tool-parser-plugin",
......@@ -228,7 +246,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help=
"Special the tool parser plugin write to parse the model-generated tool"
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser.")
"in ``--tool-call-parser``.")
parser = AsyncEngineArgs.add_cli_args(parser)
......@@ -243,7 +261,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--disable-fastapi-docs",
action='store_true',
default=False,
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."
)
parser.add_argument(
"--enable-prompt-tokens-details",
......@@ -267,6 +285,18 @@ def validate_parsed_serve_args(args: argparse.Namespace):
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
# Enable reasoning needs a reasoning parser to be valid
if args.enable_reasoning and not args.reasoning_parser:
raise TypeError("Error: --enable-reasoning requires "
"--reasoning-parser")
# Ref https://api-docs.deepseek.com/guides/reasoning_model
# tool call and reasoning cannot be enabled at the same time.
if args.enable_auto_tool_choice and args.enable_reasoning:
raise TypeError(
"Error: --enable-auto-tool-choice and "
"--enable-reasoning cannot be enabled at the same time")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
......
......@@ -3,10 +3,11 @@
import re
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
ValidationInfo, field_validator, model_validator)
from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
......@@ -42,24 +43,32 @@ class OpenAIBaseModel(BaseModel):
# OpenAI API does allow extra fields
model_config = ConfigDict(extra="allow")
@model_validator(mode="before")
# Cache class field names
field_names: ClassVar[Optional[Set[str]]] = None
@model_validator(mode="wrap")
@classmethod
def __log_extra_fields__(cls, data):
if isinstance(data, dict):
def __log_extra_fields__(cls, data, handler):
result = handler(data)
if not isinstance(data, dict):
return result
field_names = cls.field_names
if field_names is None:
# Get all class field names and their potential aliases
field_names = set()
for field_name, field in cls.model_fields.items():
field_names.add(field_name)
if hasattr(field, 'alias') and field.alias:
field_names.add(field.alias)
if alias := getattr(field, 'alias', None):
field_names.add(alias)
cls.field_names = field_names
# Compare against both field names and aliases
extra_fields = data.keys() - field_names
if extra_fields:
if any(k not in field_names for k in data):
logger.warning(
"The following fields were present in the request "
"but ignored: %s", extra_fields)
return data
"but ignored: %s",
data.keys() - field_names)
return result
class ErrorResponse(OpenAIBaseModel):
......@@ -372,13 +381,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
......@@ -398,11 +411,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
......@@ -732,13 +750,17 @@ class CompletionRequest(OpenAIBaseModel):
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)
......@@ -756,11 +778,16 @@ class CompletionRequest(OpenAIBaseModel):
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
......@@ -992,6 +1019,52 @@ class ScoreRequest(OpenAIBaseModel):
return PoolingParams(additional_data=self.additional_data)
class RerankRequest(OpenAIBaseModel):
model: str
query: str
documents: List[str]
top_n: int = Field(default_factory=lambda: 0)
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: begin-rerank-pooling-params
additional_data: Optional[Any] = None
# doc: end-rerank-pooling-params
# doc: begin-rerank-extra-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-rerank-extra-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
class RerankDocument(BaseModel):
text: str
class RerankResult(BaseModel):
index: int
document: RerankDocument
relevance_score: float
class RerankUsage(BaseModel):
total_tokens: int
class RerankResponse(OpenAIBaseModel):
id: str
model: str
usage: RerankUsage
results: List[RerankResult]
class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
......@@ -1130,6 +1203,7 @@ class ExtractedToolCallInformation(BaseModel):
class ChatMessage(OpenAIBaseModel):
role: str
reasoning_content: Optional[str] = None
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
......@@ -1171,6 +1245,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: List[DeltaToolCall] = Field(default_factory=list)
......@@ -1211,7 +1286,21 @@ class BatchRequestInput(OpenAIBaseModel):
url: str
# The parameters of the request.
body: Union[ChatCompletionRequest, EmbeddingRequest]
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
@field_validator('body', mode='plain')
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url = info.data['url']
if url == "/v1/chat/completions":
return ChatCompletionRequest.model_validate(value)
if url == "/v1/embeddings":
return TypeAdapter(EmbeddingRequest).validate_python(value)
if url == "/v1/score":
return ScoreRequest.model_validate(value)
return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
ScoreRequest]).validate_python(value)
class BatchResponseData(OpenAIBaseModel):
......@@ -1222,7 +1311,8 @@ class BatchResponseData(OpenAIBaseModel):
request_id: str
# The body of the response.
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
ScoreResponse]] = None
class BatchRequestOutput(OpenAIBaseModel):
......
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
__all__ = [
"ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser"
]
import os
from functools import cached_property
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import import_from_path, is_list_of
logger = init_logger(__name__)
class ReasoningParser:
"""
Abstract reasoning parser class that should not be used directly.
Provided and methods should be used in derived classes.
It is used to extract reasoning content from the model output.
"""
def __init__(self, tokenizer: AnyTokenizer):
self.model_tokenizer = tokenizer
@cached_property
def vocab(self) -> Dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> Tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from a complete model-generated string.
Used for non-streaming responses where we have the entire model response
available before sending to the client.
Parameters:
model_output: str
The model-generated string to extract reasoning content from.
request: ChatCompletionRequest
The request object that was used to generate the model_output.
Returns:
Tuple[Optional[str], Optional[str]]
A tuple containing the reasoning content and the content.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_calls "
"has not been implemented!")
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""
Instance method that should be implemented for extracting reasoning
from an incomplete response; for use when handling reasoning calls and
streaming. Has to be an instance method because it requires state -
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!")
class ReasoningParserManager:
reasoning_parsers: Dict[str, Type] = {}
@classmethod
def get_reasoning_parser(cls, name) -> Type:
"""
Get reasoning parser by name which is registered by `register_module`.
Raise a KeyError exception if the name is not registered.
"""
if name in cls.reasoning_parsers:
return cls.reasoning_parsers[name]
raise KeyError(f"reasoning helper: '{name}' not found in "
"reasoning_parsers")
@classmethod
def _register_module(cls,
module: Type,
module_name: Optional[Union[str, List[str]]] = None,
force: bool = True) -> None:
if not issubclass(module, ReasoningParser):
raise TypeError("module must be subclass of ReasoningParser, "
f"but got {type(module)}")
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in cls.reasoning_parsers:
existed_module = cls.reasoning_parsers[name]
raise KeyError(f"{name} is already registered "
f"at {existed_module.__module__}")
cls.reasoning_parsers[name] = module
@classmethod
def register_module(
cls,
name: Optional[Union[str, List[str]]] = None,
force: bool = True,
module: Union[Type, None] = None) -> Union[type, Callable]:
"""
Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not
None).
"""
if not isinstance(force, bool):
raise TypeError(f"force must be a boolean, but got {type(force)}")
# raise the error ahead of time
if not (name is None or isinstance(name, str)
or is_list_of(name, str)):
raise TypeError(
"name must be None, an instance of str, or a sequence of str, "
f"but got {type(name)}")
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
cls._register_module(module=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(module):
cls._register_module(module=module, module_name=name, force=force)
return module
return _register
@classmethod
def import_reasoning_parser(cls, plugin_path: str) -> None:
"""
Import a user-defined reasoning parser by the path
of the reasoning parser define file.
"""
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
try:
import_from_path(module_name, plugin_path)
except Exception:
logger.exception("Failed to load module '%s' from %s.",
module_name, plugin_path)
return
import re
from typing import Optional, Sequence, Tuple, Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
@ReasoningParserManager.register_module("deepseek_r1")
class DeepSeekR1ReasoningParser(ReasoningParser):
"""
Reasoning parser for DeepSeek R1 model.
The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
text. This parser extracts the reasoning content from the model output.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
self.reasoning_regex = re.compile(
rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
self.think_start_token_id = self.vocab.get(self.think_start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if (self.think_start_token_id is None
or self.think_end_token_id is None):
raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""
Extract reasoning content from a delta message.
Handles streaming output where previous + delta = current.
Uses token IDs for faster processing.
For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
self.think_start_token_id, self.think_end_token_id
]):
return None
if self.think_start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
# <think> in previous, </think> in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
else:
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids:
logger.info(delta_text)
if self.think_end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token)
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[start_index +
len(self.think_start_token
):end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
else:
# <think> in delta, no </think> in delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
else:
# No <think> in previous or delta, reasoning content continues.
return DeltaMessage(content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> Tuple[Optional[str], Optional[str]]:
# Check if the model output contains the <think> tokens.
if (self.think_start_token not in model_output
or self.think_end_token not in model_output):
return None, model_output
else:
# Use a regex to find the reasoning content
reasoning_content = self.reasoning_regex.findall(model_output)[0]
# Remove the reasoning content from the model output
# Although deepseek's <think> token is always at the
# beginning of the line, we cannot guarantee that the
# other models will follow this convention.
# Therefore, we need to add :start_index.
start_index = model_output.find(self.think_start_token)
if start_index != -1:
end_index = start_index + len(
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
)
model_output = model_output[:start_index] + \
model_output[end_index:]
if len(model_output) == 0:
return reasoning_content, None
return reasoning_content, model_output
......@@ -16,11 +16,14 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
EmbeddingResponse, ErrorResponse)
EmbeddingResponse, ErrorResponse,
ScoreResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
......@@ -166,7 +169,8 @@ async def run_request(serving_engine_func: Callable,
tracker: BatchProgressTracker) -> BatchRequestOutput:
response = await serving_engine_func(request.body)
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
if isinstance(response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
......@@ -213,13 +217,18 @@ async def main(args):
request_logger = RequestLogger(max_log_len=args.max_log_len)
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
)
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
base_model_paths,
openai_serving_models,
args.response_role,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
......@@ -228,11 +237,17 @@ async def main(args):
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
base_model_paths,
openai_serving_models,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if model_config.task == "embed" else None
openai_serving_scores = (OpenAIServingScores(
engine,
model_config,
openai_serving_models,
request_logger=request_logger,
) if model_config.task == "score" else None)
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
......@@ -273,14 +288,28 @@ async def main(args):
))
continue
response_futures.append(run_request(handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/score":
handler_fn = (None if openai_serving_scores is None else
openai_serving_scores.create_score)
if handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Scores API",
))
continue
response_futures.append(run_request(handler_fn, request, tracker))
tracker.submitted()
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg="Only /v1/chat/completions and "
"/v1/embeddings are supported in the batch endpoint.",
error_msg=
"Only /v1/chat/completions, /v1/embeddings, and /v1/score "
"are supported in the batch endpoint.",
))
with tracker.pbar():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment