Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
...@@ -29,7 +29,7 @@ prometheus_client.disable_created_metrics() ...@@ -29,7 +29,7 @@ prometheus_client.disable_created_metrics()
# to extract the metrics definitions. # to extract the metrics definitions.
# begin-metrics-definitions # --8<-- [start:metrics-definitions]
class Metrics: class Metrics:
""" """
vLLM uses a multiprocessing-based frontend for the OpenAI server. vLLM uses a multiprocessing-based frontend for the OpenAI server.
...@@ -293,7 +293,7 @@ class Metrics: ...@@ -293,7 +293,7 @@ class Metrics:
labelnames=labelnames)) labelnames=labelnames))
# end-metrics-definitions # --8<-- [end:metrics-definitions]
def _unregister_vllm_metrics(self) -> None: def _unregister_vllm_metrics(self) -> None:
for collector in list(prometheus_client.REGISTRY._collector_to_names): for collector in list(prometheus_client.REGISTRY._collector_to_names):
......
...@@ -492,8 +492,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -492,8 +492,9 @@ class MQLLMEngineClient(EngineClient):
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType` prompt: The prompt to the LLM. See
for more details about the format of each input. [`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
sampling_params: The sampling parameters of the request. sampling_params: The sampling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
...@@ -561,8 +562,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -561,8 +562,9 @@ class MQLLMEngineClient(EngineClient):
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType` prompt: The prompt to the LLM. See
for more details about the format of each input. [`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request. pooling_params: The pooling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
......
...@@ -42,19 +42,22 @@ HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) ...@@ -42,19 +42,22 @@ HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
class MQLLMEngine: class MQLLMEngine:
"""A multiprocessing wrapper for {class}`LLMEngine`. """A multiprocessing wrapper for
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
This class is used to wrap the {class}`LLMEngine` class to enable use This class is used to wrap the
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
in concurrnet manner. It runs a background loop and uses zeromq to in concurrnet manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc. receive new requests and stream outputs incrementally via ipc.
The {class}`LLMEngine` generate or encode process is kicked off when a new The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
RPCProcessRequest is received by the input_socket. process is kicked off when a new RPCProcessRequest is received by the
input_socket.
The self.engine_loop checks the input_socket for new requests, The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal adds them to the LLMEngine if there are any, calls the internal
{class}`LLMEngine.step()`, and sends the RequestOutputs back over [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends
the output_socket. the RequestOutputs back over the output_socket.
If use_async_sockets is set, the logic associated with reading new If use_async_sockets is set, the logic associated with reading new
requests from the socket and sending data to the socket is passed requests from the socket and sending data to the socket is passed
...@@ -65,8 +68,8 @@ class MQLLMEngine: ...@@ -65,8 +68,8 @@ class MQLLMEngine:
ipc_path: Base path for zeromq interprocess messaging ipc_path: Base path for zeromq interprocess messaging
use_async_sockets: Whether to make send/recv async with GPU use_async_sockets: Whether to make send/recv async with GPU
log_requests: Whether to log the requests. log_requests: Whether to log the requests.
*args: Arguments for {class}`LLMEngine`. *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
**kwargs: Arguments for {class}`LLMEngine`. **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
""" """
def __init__(self, def __init__(self,
......
...@@ -56,8 +56,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -56,8 +56,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
scheduled computation. scheduled computation.
Args: Args:
seq_group: the outputs are associated with this {class}`SequenceGroup` seq_group: the outputs are associated with this
outputs: the {class}`SequenceGroupOutput`s for all scheduler steps [`SequenceGroup`][vllm.sequence.SequenceGroup]
outputs: the
[`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s
for all scheduler steps
""" """
for output in outputs: for output in outputs:
# Concatenate single-step prompt logprob processing results. # Concatenate single-step prompt logprob processing results.
...@@ -67,7 +70,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -67,7 +70,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@staticmethod @staticmethod
@functools.lru_cache @functools.lru_cache
def _log_prompt_logprob_unsupported_warning_once(): def _log_prompt_logprob_unsupported_warning_once():
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid # If the feature combo become valid
logger.warning( logger.warning(
"Prompt logprob is not supported by multi step workers. " "Prompt logprob is not supported by multi step workers. "
......
...@@ -19,17 +19,21 @@ logger = init_logger(__name__) ...@@ -19,17 +19,21 @@ logger = init_logger(__name__)
def single_step_process_prompt_logprob( def single_step_process_prompt_logprob(
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
output: CompletionSequenceGroupOutput) -> None: output: CompletionSequenceGroupOutput) -> None:
"""Process prompt logprobs associated with the {class}`SequenceGroupOutput` """Process prompt logprobs associated with the
for a given step. [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step.
Do nothing if the output has no prompt logprobs. Do nothing if the output has no prompt logprobs.
Account for the fact that transformers do not compute first-token logprobs. Account for the fact that transformers do not compute first-token logprobs.
Args: Args:
sg_output_proc: {class}`SequenceGroupOutputProcessor` instance sg_output_proc:
seq_group: the output is associated with this {class}`SequenceGroup` [`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor]
output: the {class}`SequenceGroupOutput` for a single scheduler step instance
seq_group: the output is associated with this
[`SequenceGroup`][vllm.sequence.SequenceGroup]
output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
for a single scheduler step
""" """
prompt_logprobs = output.prompt_logprobs prompt_logprobs = output.prompt_logprobs
...@@ -103,8 +107,11 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -103,8 +107,11 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduled computation. scheduled computation.
Args: Args:
seq_group: the output is associated with this {class}`SequenceGroup` seq_group: the output is associated with this
outputs: the {class}`SequenceGroupOutput` for a single scheduler step [`SequenceGroup`][vllm.sequence.SequenceGroup]
outputs: the
[`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
for a single scheduler step
""" """
assert len(outputs) == 1, "Single step should only have 1 output." assert len(outputs) == 1, "Single step should only have 1 output."
output = outputs[0] output = outputs[0]
......
...@@ -556,6 +556,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -556,6 +556,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "(<audio>./</audio>)" return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video": elif modality == "video":
if model_type == "internvl_chat":
return "<video>"
if model_type in ("qwen2_vl", "qwen2_5_vl"): if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>" return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni": if model_type == "qwen2_5_omni":
......
...@@ -9,7 +9,7 @@ import vllm.entrypoints.cli.collect_env ...@@ -9,7 +9,7 @@ import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve import vllm.entrypoints.cli.serve
import vllm.version import vllm.version
from vllm.entrypoints.utils import cli_env_setup from vllm.entrypoints.utils import VLLM_SERVE_PARSER_EPILOG, cli_env_setup
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
CMD_MODULES = [ CMD_MODULES = [
...@@ -32,7 +32,10 @@ def register_signal_handlers(): ...@@ -32,7 +32,10 @@ def register_signal_handlers():
def main(): def main():
cli_env_setup() cli_env_setup()
parser = FlexibleArgumentParser(description="vLLM CLI") parser = FlexibleArgumentParser(
description="vLLM CLI",
epilog=VLLM_SERVE_PARSER_EPILOG,
)
parser.add_argument('-v', parser.add_argument('-v',
'--version', '--version',
action='version', action='version',
......
...@@ -11,6 +11,8 @@ from vllm.entrypoints.cli.types import CLISubcommand ...@@ -11,6 +11,8 @@ from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser, from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args) validate_parsed_serve_args)
from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG,
show_filtered_argument_or_group_from_help)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri from vllm.utils import FlexibleArgumentParser, get_tcp_uri
...@@ -77,7 +79,10 @@ class ServeSubcommand(CLISubcommand): ...@@ -77,7 +79,10 @@ class ServeSubcommand(CLISubcommand):
"https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference" "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference"
) )
return make_arg_parser(serve_parser) serve_parser = make_arg_parser(serve_parser)
show_filtered_argument_or_group_from_help(serve_parser)
serve_parser.epilog = VLLM_SERVE_PARSER_EPILOG
return serve_parser
def cmd_init() -> list[CLISubcommand]: def cmd_init() -> list[CLISubcommand]:
......
...@@ -4,7 +4,8 @@ import itertools ...@@ -4,7 +4,8 @@ import itertools
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
cast, overload)
import cloudpickle import cloudpickle
import torch.nn as nn import torch.nn as nn
...@@ -47,6 +48,9 @@ from vllm.usage.usage_lib import UsageContext ...@@ -47,6 +48,9 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of) is_list_of)
if TYPE_CHECKING:
from vllm.v1.metrics.reader import Metric
logger = init_logger(__name__) logger = init_logger(__name__)
_R = TypeVar("_R", default=Any) _R = TypeVar("_R", default=Any)
...@@ -116,7 +120,8 @@ class LLM: ...@@ -116,7 +120,8 @@ class LLM:
to eager mode. Additionally for encoder-decoder models, if the to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall sequence length of the encoder input is larger than this, we fall
back to the eager mode. back to the eager mode.
disable_custom_all_reduce: See {class}`~vllm.config.ParallelConfig` disable_custom_all_reduce: See
[ParallelConfig][vllm.config.ParallelConfig].
disable_async_output_proc: Disable async output processing. disable_async_output_proc: Disable async output processing.
This may result in lower performance. This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files hf_token: The token to use as HTTP bearer authorization for remote files
...@@ -128,13 +133,11 @@ class LLM: ...@@ -128,13 +133,11 @@ class LLM:
compilation_config: Either an integer or a dictionary. If it is an compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the level of compilation optimization. If it integer, it is used as the level of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration. is a dictionary, it can specify the full compilation configuration.
**kwargs: Arguments for {class}`~vllm.EngineArgs`. (See **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
{ref}`engine-args`)
:::{note} Note:
This class is intended to be used for offline inference. For online This class is intended to be used for offline inference. For online
serving, use the {class}`~vllm.AsyncLLMEngine` class instead. serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
:::
""" """
DEPRECATE_LEGACY: ClassVar[bool] = True DEPRECATE_LEGACY: ClassVar[bool] = True
...@@ -143,7 +146,7 @@ class LLM: ...@@ -143,7 +146,7 @@ class LLM:
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
""" """
A flag to toggle whether to deprecate positional arguments in A flag to toggle whether to deprecate positional arguments in
{meth}`LLM.__init__`. [LLM.__init__][].
""" """
@classmethod @classmethod
...@@ -404,7 +407,7 @@ class LLM: ...@@ -404,7 +407,7 @@ class LLM:
Args: Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See {class}`~vllm.inputs.PromptType` for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompts. for more details about the format of each prompts.
sampling_params: The sampling parameters for text generation. If sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters. None, we use the default sampling parameters.
...@@ -422,11 +425,10 @@ class LLM: ...@@ -422,11 +425,10 @@ class LLM:
A list of `RequestOutput` objects containing the A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts. generated completions in the same order as the input prompts.
:::{note} Note:
Using `prompts` and `prompt_token_ids` as keyword parameters is Using `prompts` and `prompt_token_ids` as keyword parameters is
considered legacy and may be deprecated in the future. You should considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter. instead pass them via the `inputs` parameter.
:::
""" """
runner_type = self.llm_engine.model_config.runner_type runner_type = self.llm_engine.model_config.runner_type
if runner_type not in ["generate", "transcription"]: if runner_type not in ["generate", "transcription"]:
...@@ -495,17 +497,16 @@ class LLM: ...@@ -495,17 +497,16 @@ class LLM:
`self` argument, in addition to the arguments passed in `args` `self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object. and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a timeout: Maximum time in seconds to wait for execution. Raises a
{exc}`TimeoutError` on timeout. `None` means wait indefinitely. [`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method. args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method. kwargs: Keyword arguments to pass to the worker method.
Returns: Returns:
A list containing the results from each worker. A list containing the results from each worker.
:::{note} Note:
It is recommended to use this API to only pass control messages, It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data. and set up data-plane communication to pass data.
:::
""" """
return self.llm_engine.collective_rpc(method, timeout, args, kwargs) return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
...@@ -672,7 +673,7 @@ class LLM: ...@@ -672,7 +673,7 @@ class LLM:
Generate responses for a chat conversation. Generate responses for a chat conversation.
The chat conversation is converted into a text prompt using the The chat conversation is converted into a text prompt using the
tokenizer and calls the {meth}`generate` method to generate the tokenizer and calls the [generate][] method to generate the
responses. responses.
Multi-modal inputs can be passed in the same way you would pass them Multi-modal inputs can be passed in the same way you would pass them
...@@ -681,8 +682,8 @@ class LLM: ...@@ -681,8 +682,8 @@ class LLM:
Args: Args:
messages: A list of conversations or a single conversation. messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages. - Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys. - Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation. sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it If None, we use the default sampling parameters. When it
...@@ -692,27 +693,27 @@ class LLM: ...@@ -692,27 +693,27 @@ class LLM:
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat. chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used. If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content. chat_template_content_format: The format to render message content.
- "string" will render the content as a string. - "string" will render the content as a string.
Example: ``"Who are you?"`` Example: `"Who are you?"`
- "openai" will render the content as a list of dictionaries, - "openai" will render the content as a list of dictionaries,
similar to OpenAI schema. similar to OpenAI schema.
Example: ``[{"type": "text", "text": "Who are you?"}]`` Example: `[{"type": "text", "text": "Who are you?"}]`
add_generation_prompt: If True, adds a generation template add_generation_prompt: If True, adds a generation template
to each message. to each message.
continue_final_message: If True, continues the final message in continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be the conversation instead of starting a new one. Cannot be
``True`` if ``add_generation_prompt`` is also ``True``. `True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat chat_template_kwargs: Additional kwargs to pass to the chat
template. template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests. chat request. Only used for offline requests.
Returns: Returns:
A list of ``RequestOutput`` objects containing the generated A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages. responses in the same order as the input messages.
""" """
list_of_messages: list[list[ChatCompletionMessageParam]] list_of_messages: list[list[ChatCompletionMessageParam]]
...@@ -911,7 +912,7 @@ class LLM: ...@@ -911,7 +912,7 @@ class LLM:
Args: Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See {class}`~vllm.inputs.PromptType` for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompts. for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters. use the default pooling parameters.
...@@ -924,11 +925,10 @@ class LLM: ...@@ -924,11 +925,10 @@ class LLM:
A list of `PoolingRequestOutput` objects containing the A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts. pooled hidden states in the same order as the input prompts.
:::{note} Note:
Using `prompts` and `prompt_token_ids` as keyword parameters is Using `prompts` and `prompt_token_ids` as keyword parameters is
considered legacy and may be deprecated in the future. You should considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter. instead pass them via the `inputs` parameter.
:::
""" """
runner_type = self.llm_engine.model_config.runner_type runner_type = self.llm_engine.model_config.runner_type
if runner_type != "pooling": if runner_type != "pooling":
...@@ -1001,7 +1001,7 @@ class LLM: ...@@ -1001,7 +1001,7 @@ class LLM:
Args: Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See {class}`~vllm.inputs.PromptType` for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompts. for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters. use the default pooling parameters.
...@@ -1011,7 +1011,7 @@ class LLM: ...@@ -1011,7 +1011,7 @@ class LLM:
generation, if any. generation, if any.
Returns: Returns:
A list of ``EmbeddingRequestOutput`` objects containing the A list of `EmbeddingRequestOutput` objects containing the
embedding vectors in the same order as the input prompts. embedding vectors in the same order as the input prompts.
""" """
if self.llm_engine.model_config.task != "embed": if self.llm_engine.model_config.task != "embed":
...@@ -1045,7 +1045,7 @@ class LLM: ...@@ -1045,7 +1045,7 @@ class LLM:
Args: Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See {class}`~vllm.inputs.PromptType` for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompts. for more details about the format of each prompts.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
...@@ -1053,7 +1053,7 @@ class LLM: ...@@ -1053,7 +1053,7 @@ class LLM:
generation, if any. generation, if any.
Returns: Returns:
A list of ``ClassificationRequestOutput`` objects containing the A list of `ClassificationRequestOutput` objects containing the
embedding vectors in the same order as the input prompts. embedding vectors in the same order as the input prompts.
""" """
if self.llm_engine.model_config.task != "classify": if self.llm_engine.model_config.task != "classify":
...@@ -1163,11 +1163,11 @@ class LLM: ...@@ -1163,11 +1163,11 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs ``<text,text_pair>``. """Generate similarity scores for all pairs `<text,text_pair>`.
The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``. The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N`` In the `1 - N` case the `text_1` sentence will be replicated `N`
times to pair with the ``text_2`` sentences. times to pair with the `text_2` sentences.
The input pairs are used to build a list of prompts for the The input pairs are used to build a list of prompts for the
cross encoder model. This class automatically batches the prompts, cross encoder model. This class automatically batches the prompts,
considering the memory constraint. For the best performance, put all considering the memory constraint. For the best performance, put all
...@@ -1175,9 +1175,9 @@ class LLM: ...@@ -1175,9 +1175,9 @@ class LLM:
Args: Args:
text_1: can be a single prompt or a list of prompts, in which text_1: can be a single prompt or a list of prompts, in which
case it has to have the same length as the ``text_2`` list case it has to have the same length as the `text_2` list
text_2: The texts to pair with the query to form the input text_2: The texts to pair with the query to form the input
to the LLM. See {class}`~vllm.inputs.PromptType` for to the LLM. See [PromptType][vllm.inputs.PromptType] for
more details about the format of each prompts. more details about the format of each prompts.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
...@@ -1185,7 +1185,7 @@ class LLM: ...@@ -1185,7 +1185,7 @@ class LLM:
generation, if any. generation, if any.
Returns: Returns:
A list of ``ScoringRequestOutput`` objects containing the A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts. generated scores in the same order as the input prompts.
""" """
runner_type = self.llm_engine.model_config.runner_type runner_type = self.llm_engine.model_config.runner_type
...@@ -1286,18 +1286,32 @@ class LLM: ...@@ -1286,18 +1286,32 @@ class LLM:
def wake_up(self, tags: Optional[list[str]] = None): def wake_up(self, tags: Optional[list[str]] = None):
""" """
Wake up the engine from sleep mode. See the {meth}`sleep` method Wake up the engine from sleep mode. See the [sleep][] method
for more details. for more details.
Args: Args:
tags: An optional list of tags to reallocate the engine memory tags: An optional list of tags to reallocate the engine memory
for specific memory allocations. Values must be in for specific memory allocations. Values must be in
("weights", "kv_cache",). If None, all memory is reallocated. `("weights", "kv_cache")`. If None, all memory is reallocated.
wake_up should be called with all tags (or None) before the wake_up should be called with all tags (or None) before the
engine is used again. engine is used again.
""" """
self.llm_engine.wake_up(tags) self.llm_engine.wake_up(tags)
def get_metrics(self) -> list["Metric"]:
"""Return a snapshot of aggregated metrics from Prometheus.
Returns:
A ``MetricSnapshot`` instance capturing the current state
of all aggregated metrics from Prometheus.
Note:
This method is only available with the V1 LLM engine.
"""
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
assert isinstance(self.llm_engine, V1LLMEngine)
return self.llm_engine.get_metrics()
# LEGACY # LEGACY
def _convert_v1_inputs( def _convert_v1_inputs(
self, self,
...@@ -1306,27 +1320,25 @@ class LLM: ...@@ -1306,27 +1320,25 @@ class LLM:
): ):
# skip_tokenizer_init is now checked in engine # skip_tokenizer_init is now checked in engine
if prompts is None and prompt_token_ids is None:
raise ValueError(
"Either prompts or prompt_token_ids must be provided.")
if prompts is not None and prompt_token_ids is not None \
and len(prompts) != len(prompt_token_ids):
raise ValueError(
"The lengths of prompts and prompt_token_ids must be the same."
)
if prompts is not None: if prompts is not None:
prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
if prompt_token_ids is not None: if prompt_token_ids is not None:
prompt_token_ids = [ prompt_token_ids = [
p["content"] for p in parse_and_batch_prompt(prompt_token_ids) p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
] ]
num_requests = None
if prompts is not None: if prompts is not None:
num_requests = len(prompts) num_requests = len(prompts)
if prompt_token_ids is not None: elif prompt_token_ids is not None:
if (num_requests is not None
and num_requests != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
num_requests = len(prompt_token_ids) num_requests = len(prompt_token_ids)
if num_requests is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
parsed_prompts: list[PromptType] = [] parsed_prompts: list[PromptType] = []
for i in range(num_requests): for i in range(num_requests):
item: PromptType item: PromptType
......
...@@ -7,7 +7,6 @@ import importlib ...@@ -7,7 +7,6 @@ import importlib
import inspect import inspect
import multiprocessing import multiprocessing
import os import os
import re
import signal import signal
import socket import socket
import tempfile import tempfile
...@@ -21,6 +20,7 @@ from json import JSONDecodeError ...@@ -21,6 +20,7 @@ from json import JSONDecodeError
from typing import Annotated, Optional, Union from typing import Annotated, Optional, Union
import prometheus_client import prometheus_client
import regex as re
import uvloop import uvloop
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json import json
import re
import time import time
from http import HTTPStatus from http import HTTPStatus
from typing import Annotated, Any, ClassVar, Literal, Optional, Union from typing import Annotated, Any, ClassVar, Literal, Optional, Union
import regex as re
import torch import torch
from fastapi import HTTPException, UploadFile from fastapi import HTTPException, UploadFile
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
...@@ -251,7 +251,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -251,7 +251,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
parallel_tool_calls: Optional[bool] = False parallel_tool_calls: Optional[bool] = False
user: Optional[str] = None user: Optional[str] = None
# doc: begin-chat-completion-sampling-params # --8<-- [start:chat-completion-sampling-params]
best_of: Optional[int] = None best_of: Optional[int] = None
use_beam_search: bool = False use_beam_search: bool = False
top_k: Optional[int] = None top_k: Optional[int] = None
...@@ -266,9 +266,9 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -266,9 +266,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens: bool = True spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
prompt_logprobs: Optional[int] = None prompt_logprobs: Optional[int] = None
# doc: end-chat-completion-sampling-params # --8<-- [end:chat-completion-sampling-params]
# doc: begin-chat-completion-extra-params # --8<-- [start:chat-completion-extra-params]
echo: bool = Field( echo: bool = Field(
default=False, default=False,
description=( description=(
...@@ -407,7 +407,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -407,7 +407,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None, default=None,
description="KVTransfer parameters used for disaggregated serving.") description="KVTransfer parameters used for disaggregated serving.")
# doc: end-chat-completion-extra-params # --8<-- [end:chat-completion-extra-params]
# Default sampling parameters for chat completion requests # Default sampling parameters for chat completion requests
_DEFAULT_SAMPLING_PARAMS: dict = { _DEFAULT_SAMPLING_PARAMS: dict = {
...@@ -764,7 +764,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -764,7 +764,7 @@ class CompletionRequest(OpenAIBaseModel):
top_p: Optional[float] = None top_p: Optional[float] = None
user: Optional[str] = None user: Optional[str] = None
# doc: begin-completion-sampling-params # --8<-- [start:completion-sampling-params]
use_beam_search: bool = False use_beam_search: bool = False
top_k: Optional[int] = None top_k: Optional[int] = None
min_p: Optional[float] = None min_p: Optional[float] = None
...@@ -779,9 +779,9 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -779,9 +779,9 @@ class CompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[list[int]] = None allowed_token_ids: Optional[list[int]] = None
prompt_logprobs: Optional[int] = None prompt_logprobs: Optional[int] = None
# doc: end-completion-sampling-params # --8<-- [end:completion-sampling-params]
# doc: begin-completion-extra-params # --8<-- [start:completion-extra-params]
add_special_tokens: bool = Field( add_special_tokens: bool = Field(
default=True, default=True,
description=( description=(
...@@ -858,7 +858,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -858,7 +858,7 @@ class CompletionRequest(OpenAIBaseModel):
default=None, default=None,
description="KVTransfer parameters used for disaggregated serving.") description="KVTransfer parameters used for disaggregated serving.")
# doc: end-completion-extra-params # --8<-- [end:completion-extra-params]
# Default sampling parameters for completion requests # Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = { _DEFAULT_SAMPLING_PARAMS: dict = {
...@@ -1045,11 +1045,11 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): ...@@ -1045,11 +1045,11 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
user: Optional[str] = None user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# doc: begin-embedding-pooling-params # --8<-- [start:embedding-pooling-params]
additional_data: Optional[Any] = None additional_data: Optional[Any] = None
# doc: end-embedding-pooling-params # --8<-- [end:embedding-pooling-params]
# doc: begin-embedding-extra-params # --8<-- [start:embedding-extra-params]
add_special_tokens: bool = Field( add_special_tokens: bool = Field(
default=True, default=True,
description=( description=(
...@@ -1064,7 +1064,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): ...@@ -1064,7 +1064,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
"if the served model does not use priority scheduling."), "if the served model does not use priority scheduling."),
) )
# doc: end-embedding-extra-params # --8<-- [end:embedding-extra-params]
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(dimensions=self.dimensions, return PoolingParams(dimensions=self.dimensions,
...@@ -1080,11 +1080,11 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -1080,11 +1080,11 @@ class EmbeddingChatRequest(OpenAIBaseModel):
user: Optional[str] = None user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# doc: begin-chat-embedding-pooling-params # --8<-- [start:chat-embedding-pooling-params]
additional_data: Optional[Any] = None additional_data: Optional[Any] = None
# doc: end-chat-embedding-pooling-params # --8<-- [end:chat-embedding-pooling-params]
# doc: begin-chat-embedding-extra-params # --8<-- [start:chat-embedding-extra-params]
add_special_tokens: bool = Field( add_special_tokens: bool = Field(
default=False, default=False,
description=( description=(
...@@ -1118,7 +1118,7 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -1118,7 +1118,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
"default: 0). Any priority other than 0 will raise an error " "default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."), "if the served model does not use priority scheduling."),
) )
# doc: end-chat-embedding-extra-params # --8<-- [end:chat-embedding-extra-params]
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
...@@ -1147,11 +1147,11 @@ class ScoreRequest(OpenAIBaseModel): ...@@ -1147,11 +1147,11 @@ class ScoreRequest(OpenAIBaseModel):
text_2: Union[list[str], str] text_2: Union[list[str], str]
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# doc: begin-score-pooling-params # --8<-- [start:score-pooling-params]
additional_data: Optional[Any] = None additional_data: Optional[Any] = None
# doc: end-score-pooling-params # --8<-- [end:score-pooling-params]
# doc: begin-score-extra-params # --8<-- [start:score-extra-params]
priority: int = Field( priority: int = Field(
default=0, default=0,
description=( description=(
...@@ -1160,7 +1160,7 @@ class ScoreRequest(OpenAIBaseModel): ...@@ -1160,7 +1160,7 @@ class ScoreRequest(OpenAIBaseModel):
"if the served model does not use priority scheduling."), "if the served model does not use priority scheduling."),
) )
# doc: end-score-extra-params # --8<-- [end:score-extra-params]
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(additional_data=self.additional_data)
...@@ -1173,11 +1173,11 @@ class RerankRequest(OpenAIBaseModel): ...@@ -1173,11 +1173,11 @@ class RerankRequest(OpenAIBaseModel):
top_n: int = Field(default_factory=lambda: 0) top_n: int = Field(default_factory=lambda: 0)
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# doc: begin-rerank-pooling-params # --8<-- [start:rerank-pooling-params]
additional_data: Optional[Any] = None additional_data: Optional[Any] = None
# doc: end-rerank-pooling-params # --8<-- [end:rerank-pooling-params]
# doc: begin-rerank-extra-params # --8<-- [start:rerank-extra-params]
priority: int = Field( priority: int = Field(
default=0, default=0,
description=( description=(
...@@ -1186,7 +1186,7 @@ class RerankRequest(OpenAIBaseModel): ...@@ -1186,7 +1186,7 @@ class RerankRequest(OpenAIBaseModel):
"if the served model does not use priority scheduling."), "if the served model does not use priority scheduling."),
) )
# doc: end-rerank-extra-params # --8<-- [end:rerank-extra-params]
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(additional_data=self.additional_data)
...@@ -1321,11 +1321,11 @@ class ClassificationRequest(OpenAIBaseModel): ...@@ -1321,11 +1321,11 @@ class ClassificationRequest(OpenAIBaseModel):
truncate_prompt_tokens: Optional[int] = None truncate_prompt_tokens: Optional[int] = None
user: Optional[str] = None user: Optional[str] = None
# doc: begin-classification-pooling-params # --8<-- [start:classification-pooling-params]
additional_data: Optional[Any] = None additional_data: Optional[Any] = None
# doc: end-classification-pooling-params # --8<-- [end:classification-pooling-params]
# doc: begin-classification-extra-params # --8<-- [start:classification-extra-params]
priority: int = Field( priority: int = Field(
default=0, default=0,
description=( description=(
...@@ -1334,7 +1334,7 @@ class ClassificationRequest(OpenAIBaseModel): ...@@ -1334,7 +1334,7 @@ class ClassificationRequest(OpenAIBaseModel):
"if the served model does not use priority scheduling."), "if the served model does not use priority scheduling."),
) )
# doc: end-classification-extra-params # --8<-- [end:classification-extra-params]
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(additional_data=self.additional_data)
...@@ -1698,7 +1698,7 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -1698,7 +1698,7 @@ class TranscriptionRequest(OpenAIBaseModel):
timestamps incurs additional latency. timestamps incurs additional latency.
""" """
# doc: begin-transcription-extra-params # --8<-- [start:transcription-extra-params]
stream: Optional[bool] = False stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set, """Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat it will enable output to be streamed in a similar fashion as the Chat
...@@ -1707,9 +1707,9 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -1707,9 +1707,9 @@ class TranscriptionRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data. # Flattened stream option to simplify form data.
stream_include_usage: Optional[bool] = False stream_include_usage: Optional[bool] = False
stream_continuous_usage_stats: Optional[bool] = False stream_continuous_usage_stats: Optional[bool] = False
# doc: end-transcription-extra-params # --8<-- [end:transcription-extra-params]
# doc: begin-transcription-sampling-params # --8<-- [start:transcription-sampling-params]
temperature: float = Field(default=0.0) temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1. """The sampling temperature, between 0 and 1.
...@@ -1743,7 +1743,7 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -1743,7 +1743,7 @@ class TranscriptionRequest(OpenAIBaseModel):
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
"""The presence penalty to use for sampling.""" """The presence penalty to use for sampling."""
# doc: end-transcription-sampling-params # --8<-- [end:transcription-sampling-params]
# Default sampling parameters for transcription requests. # Default sampling parameters for transcription requests.
_DEFAULT_SAMPLING_PARAMS: dict = { _DEFAULT_SAMPLING_PARAMS: dict = {
......
...@@ -365,8 +365,8 @@ async def main(args): ...@@ -365,8 +365,8 @@ async def main(args):
# Determine the type of request and run it. # Determine the type of request and run it.
if request.url == "/v1/chat/completions": if request.url == "/v1/chat/completions":
chat_handler_fn = (None if openai_serving_chat is None else chat_handler_fn = openai_serving_chat.create_chat_completion if \
openai_serving_chat.create_chat_completion) openai_serving_chat is not None else None
if chat_handler_fn is None: if chat_handler_fn is None:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
...@@ -380,8 +380,8 @@ async def main(args): ...@@ -380,8 +380,8 @@ async def main(args):
run_request(chat_handler_fn, request, tracker)) run_request(chat_handler_fn, request, tracker))
tracker.submitted() tracker.submitted()
elif request.url == "/v1/embeddings": elif request.url == "/v1/embeddings":
embed_handler_fn = (None if openai_serving_embedding is None else embed_handler_fn = openai_serving_embedding.create_embedding if \
openai_serving_embedding.create_embedding) openai_serving_embedding is not None else None
if embed_handler_fn is None: if embed_handler_fn is None:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
...@@ -394,8 +394,8 @@ async def main(args): ...@@ -394,8 +394,8 @@ async def main(args):
run_request(embed_handler_fn, request, tracker)) run_request(embed_handler_fn, request, tracker))
tracker.submitted() tracker.submitted()
elif request.url == "/v1/score": elif request.url == "/v1/score":
score_handler_fn = (None if openai_serving_scores is None else score_handler_fn = openai_serving_scores.create_score if \
openai_serving_scores.create_score) openai_serving_scores is not None else None
if score_handler_fn is None: if score_handler_fn is None:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import asyncio import asyncio
import json import json
import re
import time import time
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence from collections.abc import Sequence as GenericSequence
...@@ -10,6 +9,7 @@ from typing import Callable, Final, Optional, Union ...@@ -10,6 +9,7 @@ from typing import Callable, Final, Optional, Union
import jinja2 import jinja2
import partial_json_parser import partial_json_parser
import regex as re
from fastapi import Request from fastapi import Request
from pydantic import TypeAdapter from pydantic import TypeAdapter
......
...@@ -582,7 +582,8 @@ class OpenAIServing: ...@@ -582,7 +582,8 @@ class OpenAIServing:
add_special_tokens: bool = True, add_special_tokens: bool = True,
) -> TextTokensPrompt: ) -> TextTokensPrompt:
""" """
A simpler implementation of {meth}`_tokenize_prompt_input_or_inputs` A simpler implementation of
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
that assumes single input. that assumes single input.
""" """
return next( return next(
...@@ -603,7 +604,8 @@ class OpenAIServing: ...@@ -603,7 +604,8 @@ class OpenAIServing:
add_special_tokens: bool = True, add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]: ) -> Iterator[TextTokensPrompt]:
""" """
A simpler implementation of {meth}`_tokenize_prompt_input_or_inputs` A simpler implementation of
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
that assumes multiple inputs. that assumes multiple inputs.
""" """
for text in prompt_inputs: for text in prompt_inputs:
......
...@@ -7,6 +7,7 @@ from .granite_tool_parser import GraniteToolParser ...@@ -7,6 +7,7 @@ from .granite_tool_parser import GraniteToolParser
from .hermes_tool_parser import Hermes2ProToolParser from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser from .jamba_tool_parser import JambaToolParser
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
from .llama_tool_parser import Llama3JsonToolParser from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser
...@@ -16,5 +17,6 @@ __all__ = [ ...@@ -16,5 +17,6 @@ __all__ = [
"ToolParser", "ToolParserManager", "Granite20bFCToolParser", "ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
"PythonicToolParser", "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser" "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
"DeepSeekV3ToolParser"
] ]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Union from typing import Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
from collections.abc import Sequence from collections.abc import Sequence
from json import JSONDecoder from json import JSONDecoder
from typing import Union from typing import Union
import partial_json_parser import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import random_tool_call_id
...@@ -80,7 +80,8 @@ class Granite20bFCToolParser(ToolParser): ...@@ -80,7 +80,8 @@ class Granite20bFCToolParser(ToolParser):
function=FunctionCall( function=FunctionCall(
name=function_call["name"], name=function_call["name"],
# function call args are JSON but as a string # function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"]), arguments=json.dumps(function_call["arguments"],
ensure_ascii=False),
), ),
) for function_call in raw_function_calls ) for function_call in raw_function_calls
] ]
...@@ -166,7 +167,8 @@ class Granite20bFCToolParser(ToolParser): ...@@ -166,7 +167,8 @@ class Granite20bFCToolParser(ToolParser):
if self.current_tool_id >= 0: if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments") cur_arguments = current_tool_call.get("arguments")
if cur_arguments: if cur_arguments:
cur_args_json = json.dumps(cur_arguments) cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len( sent = len(
self.streamed_args_for_tool[self.current_tool_id]) self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:] argument_diff = cur_args_json[sent:]
...@@ -218,7 +220,8 @@ class Granite20bFCToolParser(ToolParser): ...@@ -218,7 +220,8 @@ class Granite20bFCToolParser(ToolParser):
if cur_arguments: if cur_arguments:
sent = len( sent = len(
self.streamed_args_for_tool[self.current_tool_id]) self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments) cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[ prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments") self.current_tool_id].get("arguments")
...@@ -226,7 +229,8 @@ class Granite20bFCToolParser(ToolParser): ...@@ -226,7 +229,8 @@ class Granite20bFCToolParser(ToolParser):
if is_complete[self.current_tool_id]: if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:] argument_diff = cur_args_json[sent:]
elif prev_arguments: elif prev_arguments:
prev_args_json = json.dumps(prev_arguments) prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json: if cur_args_json != prev_args_json:
prefix = find_common_prefix( prefix = find_common_prefix(
......
...@@ -67,7 +67,8 @@ class GraniteToolParser(ToolParser): ...@@ -67,7 +67,8 @@ class GraniteToolParser(ToolParser):
function=FunctionCall( function=FunctionCall(
name=function_call["name"], name=function_call["name"],
# function call args are JSON but as a string # function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"]), arguments=json.dumps(function_call["arguments"],
ensure_ascii=False),
), ),
) for function_call in raw_function_calls ) for function_call in raw_function_calls
] ]
...@@ -151,7 +152,8 @@ class GraniteToolParser(ToolParser): ...@@ -151,7 +152,8 @@ class GraniteToolParser(ToolParser):
if self.current_tool_id >= 0: if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments") cur_arguments = current_tool_call.get("arguments")
if cur_arguments: if cur_arguments:
cur_args_json = json.dumps(cur_arguments) cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len( sent = len(
self.streamed_args_for_tool[self.current_tool_id]) self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:] argument_diff = cur_args_json[sent:]
...@@ -197,7 +199,8 @@ class GraniteToolParser(ToolParser): ...@@ -197,7 +199,8 @@ class GraniteToolParser(ToolParser):
if cur_arguments: if cur_arguments:
sent = len( sent = len(
self.streamed_args_for_tool[self.current_tool_id]) self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments) cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[ prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments") self.current_tool_id].get("arguments")
...@@ -205,7 +208,8 @@ class GraniteToolParser(ToolParser): ...@@ -205,7 +208,8 @@ class GraniteToolParser(ToolParser):
if is_complete[self.current_tool_id]: if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:] argument_diff = cur_args_json[sent:]
elif prev_arguments: elif prev_arguments:
prev_args_json = json.dumps(prev_arguments) prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json: if cur_args_json != prev_args_json:
prefix = find_common_prefix( prefix = find_common_prefix(
prev_args_json, cur_args_json) prev_args_json, cur_args_json)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Union from typing import Union
import partial_json_parser import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import random_tool_call_id
......
...@@ -133,7 +133,8 @@ class Internlm2ToolParser(ToolParser): ...@@ -133,7 +133,8 @@ class Internlm2ToolParser(ToolParser):
delta = None delta = None
# first time to get parameters # first time to get parameters
elif cur_arguments and not prev_arguments: elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments) cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
arguments_delta = cur_arguments_json[:cur_arguments_json. arguments_delta = cur_arguments_json[:cur_arguments_json.
index(delta_text) + index(delta_text) +
...@@ -148,8 +149,10 @@ class Internlm2ToolParser(ToolParser): ...@@ -148,8 +149,10 @@ class Internlm2ToolParser(ToolParser):
self.current_tool_id] += arguments_delta self.current_tool_id] += arguments_delta
# both prev and cur parameters, send the increase parameters # both prev and cur parameters, send the increase parameters
elif cur_arguments and prev_arguments: elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments) cur_args_json = json.dumps(cur_arguments,
prev_args_json = json.dumps(prev_arguments) ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
argument_diff = extract_intermediate_diff( argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json) cur_args_json, prev_args_json)
...@@ -190,7 +193,8 @@ class Internlm2ToolParser(ToolParser): ...@@ -190,7 +193,8 @@ class Internlm2ToolParser(ToolParser):
action_dict = json.loads(action) action_dict = json.loads(action)
name, parameters = action_dict['name'], json.dumps( name, parameters = action_dict['name'], json.dumps(
action_dict.get('parameters', action_dict.get('arguments', action_dict.get('parameters', action_dict.get('arguments',
{}))) {})),
ensure_ascii=False)
if not tools or name not in [t.function.name for t in tools]: if not tools or name not in [t.function.name for t in tools]:
ExtractedToolCallInformation(tools_called=False, ExtractedToolCallInformation(tools_called=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment