Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
# SPDX-License-Identifier: Apache-2.0
import json
import re
from collections.abc import Sequence
from typing import Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
......@@ -96,8 +96,9 @@ class JambaToolParser(ToolParser):
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"])))
for function_call in raw_function_calls
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False),
)) for function_call in raw_function_calls
]
content = model_output[:model_output.
......@@ -187,7 +188,7 @@ class JambaToolParser(ToolParser):
diff: Union[str, None] = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff).replace(
diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id],
"")
delta = DeltaMessage(tool_calls=[
......@@ -248,7 +249,8 @@ class JambaToolParser(ToolParser):
"mid-arguments")
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("finding %s in %s", new_text,
cur_arguments_json)
......@@ -267,8 +269,10 @@ class JambaToolParser(ToolParser):
self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
logger.debug("Searching for diff between \n%s\n%s",
cur_args_json, prev_args_json)
......
# SPDX-License-Identifier: Apache-2.0
import ast
import json
from collections.abc import Sequence
from typing import Any, Union
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
@ToolParserManager.register_module("llama4_pythonic")
class Llama4PythonicToolParser(ToolParser):
"""
Toolcall parser for Llama4 that produce tool calls in a pythonic style
Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
return self.current_tool_id
@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
self.current_tool_id = value
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
# remove <|python_start|> and <|python_end|>
# as Llama 4 model sometime will output those tokens
if model_output.startswith("<|python_start|>"):
model_output = model_output[len("<|python_start|>"):]
model_output = model_output.replace("<|python_end|>", "")
if not (self.TOOL_CALL_REGEX.match(model_output)):
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
module = ast.parse(model_output)
parsed = getattr(module.body[0], "value", None)
if isinstance(parsed, ast.List) and all(
isinstance(e, ast.Call) for e in parsed.elts):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls")
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if not current_text.startswith("[") and not current_text.startswith(
"<|python_start|>"):
return DeltaMessage(content=delta_text)
try:
# remove <|python_start|> and <|python_end|>
if current_text.startswith("<|python_start|>"):
current_text = current_text[len("<|python_start|>"):]
if current_text.endswith("<|python_end|>"):
current_text = current_text[:current_text.
rfind("<|python_end|>")]
valid_and_added_text = _make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
module = ast.parse(valid_text)
parsed = getattr(module.body[0], "value", None)
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts):
raise _UnexpectedAstError(
"Tool output must be a list of function calls")
tool_calls = [
_handle_single_tool(e) # type: ignore
for e in parsed.elts
]
tool_deltas = []
for index, new_call in enumerate(tool_calls):
if index < self.current_tool_index:
continue
self.current_tool_index = index
if len(self.streamed_args_for_tool) == index:
self.streamed_args_for_tool.append("")
new_call_complete = index < len(
tool_calls) - 1 or ")]" not in added_text
if new_call_complete:
self.current_tool_index += 1
withheld_suffix = (added_text[:-2]
if not new_call_complete else "")
if not new_call_complete and added_text[-2] == ")":
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix = withheld_suffix + "}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(self.streamed_args_for_tool[index],
new_call, index, withheld_suffix)
if delta is not None:
tool_deltas.append(delta)
if (delta.function is not None
and delta.function.arguments is not None):
self.streamed_args_for_tool[
index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining it's final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if tool_deltas and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if tool_deltas:
return DeltaMessage(tool_calls=tool_deltas)
elif not added_text and self.current_tool_id > 0:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage(content='')
else:
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError(
"Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(arguments)))
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[:text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[:text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
"[") and not text.endswith(")"):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
index: int,
withheld_suffix: str) -> Union[DeltaToolCall, None]:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[:-len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
))
arg_diff = new_call_args[len(previously_sent_args):]
return DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(
arguments=arg_diff)) if arg_diff else None
# SPDX-License-Identifier: Apache-2.0
import json
import re
from collections.abc import Sequence
from json import JSONDecoder
from typing import Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
......@@ -88,7 +88,8 @@ class Llama3JsonToolParser(ToolParser):
# function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"] \
if "arguments" in raw_function_call \
else raw_function_call["parameters"])))
else raw_function_call["parameters"],
ensure_ascii=False)))
for raw_function_call in function_call_arr
]
......@@ -174,7 +175,8 @@ class Llama3JsonToolParser(ToolParser):
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
......@@ -226,7 +228,8 @@ class Llama3JsonToolParser(ToolParser):
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
......@@ -234,7 +237,8 @@ class Llama3JsonToolParser(ToolParser):
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(
......
# SPDX-License-Identifier: Apache-2.0
import json
import re
from collections.abc import Sequence
from random import choices
from string import ascii_letters, digits
from typing import Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from pydantic import Field
......
# SPDX-License-Identifier: Apache-2.0
import json
import re
from collections.abc import Sequence
from typing import Any, Optional
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import random_tool_call_id
......@@ -79,10 +79,11 @@ class Phi4MiniJsonToolParser(ToolParser):
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"] if "arguments" in
raw_function_call else
raw_function_call["parameters"])))
for raw_function_call in function_call_arr
raw_function_call["arguments"]
if "arguments" in raw_function_call else
raw_function_call["parameters"],
ensure_ascii=False),
)) for raw_function_call in function_call_arr
]
# get any content before the tool call
......
......@@ -2,10 +2,10 @@
import ast
import json
import re
from collections.abc import Sequence
from typing import Any, Union
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
......@@ -200,9 +200,12 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(type="function",
return ToolCall(
type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(arguments)))
arguments=json.dumps(arguments,
ensure_ascii=False)),
)
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
......
......@@ -13,6 +13,13 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
VLLM_SERVE_PARSER_EPILOG = (
"Tip: Use `vllm serve --help=<keyword>` to explore arguments from help.\n"
" - To view a argument group: --help=ModelConfig\n"
" - To view a single argument: --help=max-num-seqs\n"
" - To search by keyword: --help=max\n"
" - To list all groups: --help=listgroup")
async def listen_for_disconnect(request: Request) -> None:
"""Returns if a disconnect message is received"""
......@@ -158,3 +165,55 @@ def _validate_truncation_size(
tokenization_kwargs["max_length"] = truncate_prompt_tokens
return truncate_prompt_tokens
def show_filtered_argument_or_group_from_help(parser):
import sys
for arg in sys.argv:
if arg.startswith('--help='):
search_keyword = arg.split('=', 1)[1]
# List available groups
if search_keyword == 'listgroup':
print("\nAvailable argument groups:")
for group in parser._action_groups:
if group.title and not group.title.startswith(
"positional arguments"):
print(f" - {group.title}")
if group.description:
print(" " + group.description.strip())
print()
sys.exit(0)
# For group search
formatter = parser._get_formatter()
for group in parser._action_groups:
if group.title and group.title.lower() == search_keyword.lower(
):
formatter.start_section(group.title)
formatter.add_text(group.description)
formatter.add_arguments(group._group_actions)
formatter.end_section()
print(formatter.format_help())
sys.exit(0)
# For single arg
matched_actions = []
for group in parser._action_groups:
for action in group._group_actions:
# search option name
if any(search_keyword.lower() in opt.lower()
for opt in action.option_strings):
matched_actions.append(action)
if matched_actions:
print(f"\nParameters matching '{search_keyword}':\n")
formatter = parser._get_formatter()
formatter.add_arguments(matched_actions)
print(formatter.format_help())
sys.exit(0)
print(f"\nNo group or parameter matching '{search_keyword}'")
print("Tip: use `--help=listgroup` to view all groups.")
sys.exit(1)
......@@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
def get_default_cache_root():
......@@ -163,7 +164,7 @@ def get_vllm_port() -> Optional[int]:
raise ValueError(
f"VLLM_PORT '{port}' appears to be a URI. "
"This may be caused by a Kubernetes service discovery issue"
"check the warning in: https://docs.vllm.ai/en/stable/serving/env_vars.html"
"check the warning in: https://docs.vllm.ai/en/stable/usage/env_vars.html"
)
except Exception:
pass
......@@ -175,7 +176,7 @@ def get_vllm_port() -> Optional[int]:
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
# --8<-- [start:env-vars-definition]
environment_variables: dict[str, Callable[[], Any]] = {
......@@ -809,11 +810,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
# all2all backend for vllm's expert parallel communication
# Available options:
# - "naive": naive all2all implementation using all-reduce
# - "pplx": use pplx kernels
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
# This is used to prevent the kernel from running out of memory.
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
}
# end-env-vars-definition
# --8<-- [end:env-vars-definition]
def __getattr__(name: str):
......
......@@ -74,7 +74,7 @@ class ExecutorBase(ABC):
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
{exc}`TimeoutError` on timeout. `None` means wait indefinitely.
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
......
......@@ -528,12 +528,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
ray.get(parallel_worker_tasks)
def _check_ray_cgraph_installation(self):
import pkg_resources
import importlib.metadata
from packaging import version
required_version = version.parse("2.43.0")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
current_version = version.parse(importlib.metadata.version("ray"))
if current_version < required_version:
raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}")
......
......@@ -87,9 +87,8 @@ try:
# TODO(swang): This is needed right now because Ray Compiled Graph
# executes on a background thread, so we need to reset torch's
# current device.
import torch
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
current_platform.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
output = self.worker._execute_model_spmd(execute_model_req,
......@@ -113,8 +112,7 @@ try:
# Not needed
pass
else:
import torch
torch.cuda.set_device(self.worker.device)
current_platform.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
......
......@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
torch.cuda.synchronize()
from vllm.platforms import current_platform
synchronize = current_platform.synchronize
if synchronize is not None:
synchronize()
now = time.perf_counter()
# time measurement is in milliseconds
batchsize_forward_time[batchsize].append(
......
......@@ -10,8 +10,9 @@ from .registry import (DummyData, InputContext, InputProcessingContext,
INPUT_REGISTRY = InputRegistry()
"""
The global {class}`~InputRegistry` which is used by {class}`~vllm.LLMEngine`
to dispatch data processing according to the target model.
The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used
by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the
target model.
"""
__all__ = [
......
......@@ -80,22 +80,24 @@ SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:
- A text prompt ({class}`str` or {class}`TextPrompt`)
- A tokenized prompt ({class}`TokensPrompt`)
- An embeddings prompt ({class}`EmbedsPrompt`)
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. {class}`ExplicitEncoderDecoderPrompt`
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type {class}`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt`
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
"""
......@@ -126,18 +128,20 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively, may be formatted
according to any of the {class}`SingletonPrompt` schemas,
according to any of the
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
Note that an {class}`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
Note that an
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
may not be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
{class}`SingletonPrompt` instances.
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
"""
encoder_prompt: _T1_co
......@@ -152,11 +156,11 @@ PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
- A text prompt ({class}`str` or {class}`TextPrompt`)
- A tokenized prompt ({class}`TokensPrompt`)
- An embeddings prompt ({class}`EmbedsPrompt`)
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt
({class}`ExplicitEncoderDecoderPrompt`)
([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
"""
......@@ -189,7 +193,8 @@ def token_inputs(
prompt: Optional[str] = None,
cache_salt: Optional[str] = None,
) -> TokenInputs:
"""Construct {class}`TokenInputs` from optional values."""
"""Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
values."""
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
if prompt is not None:
......@@ -221,7 +226,8 @@ def embeds_inputs(
prompt_embeds: torch.Tensor,
cache_salt: Optional[str] = None,
) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values."""
"""Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
values."""
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
if cache_salt is not None:
......@@ -232,7 +238,7 @@ def embeds_inputs(
DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
The inputs in {class}`~vllm.LLMEngine` before they are
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
......@@ -240,11 +246,12 @@ This specifies the data required for decoder-only models.
class EncoderDecoderInputs(TypedDict):
"""
The inputs in {class}`~vllm.LLMEngine` before they are
passed to the model executor.
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
are passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder: Union[TokenInputs, "MultiModalInputs"]
"""The inputs for the encoder portion."""
......@@ -254,13 +261,13 @@ class EncoderDecoderInputs(TypedDict):
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
A processed {class}`SingletonPrompt` which can be passed to
{class}`vllm.sequence.Sequence`.
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
passed to [`vllm.sequence.Sequence`][].
"""
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to {data}`vllm.inputs.InputProcessor`.
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
"""
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
......@@ -277,7 +284,8 @@ def build_explicit_enc_dec_prompt(
return ExplicitEncoderDecoderPrompt(
encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt,
mm_processor_kwargs=mm_processor_kwargs)
mm_processor_kwargs=mm_processor_kwargs,
)
def zip_enc_dec_prompts(
......@@ -288,7 +296,8 @@ def zip_enc_dec_prompts(
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
"""
Zip encoder and decoder prompts together into a list of
{class}`ExplicitEncoderDecoderPrompt` instances.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
instances.
``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
dictionary will be used for every encoder/decoder prompt. If an iterable is
......@@ -299,9 +308,10 @@ def zip_enc_dec_prompts(
if isinstance(mm_processor_kwargs, dict):
return [
build_explicit_enc_dec_prompt(
encoder_prompt, decoder_prompt,
cast(dict[str, Any], mm_processor_kwargs))
for (encoder_prompt,
encoder_prompt,
decoder_prompt,
cast(dict[str, Any], mm_processor_kwargs),
) for (encoder_prompt,
decoder_prompt) in zip(enc_prompts, dec_prompts)
]
return [
......
......@@ -23,13 +23,13 @@ class ParsedTokens(TypedDict):
@overload
def parse_and_batch_prompt(
prompt: Union[str, list[str]]) -> Sequence[ParsedText]:
prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[list[int], list[list[int]]]) -> Sequence[ParsedTokens]:
prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
...
......@@ -86,7 +86,7 @@ class ParsedTokensPrompt(TypedDict):
class ParsedEmbedsPrompt(TypedDict):
type: Literal['embeds']
type: Literal["embeds"]
content: EmbedsPrompt
......@@ -133,7 +133,7 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
prompt: PromptType, ) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt
......
......@@ -67,11 +67,11 @@ class InputPreprocessor:
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def get_decoder_start_token_id(self) -> Optional[int]:
'''
"""
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
model config is unavailable.
'''
"""
if not self.model_config.is_encoder_decoder:
logger.warning_once(
......@@ -79,14 +79,14 @@ class InputPreprocessor:
"this is not an encoder/decoder model.")
return None
if (self.model_config is None or self.model_config.hf_config is None):
if self.model_config is None or self.model_config.hf_config is None:
logger.warning_once(
"Using None for decoder start token id because "
"model config is not available.")
return None
dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None)
"decoder_start_token_id", None)
if dec_start_token_id is None:
logger.warning_once(
"Falling back on <BOS> for decoder start token "
......@@ -97,7 +97,7 @@ class InputPreprocessor:
return dec_start_token_id
def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
'''
"""
Specifically for encoder/decoder models:
generate a default decoder prompt for when
the user specifies only the encoder prompt.
......@@ -126,7 +126,7 @@ class InputPreprocessor:
Returns:
* prompt_token_ids
'''
"""
bos_token_id = self.get_bos_token_id()
assert bos_token_id is not None
......@@ -224,7 +224,10 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""Async version of {meth}`_tokenize_prompt`."""
"""
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer = self.get_tokenizer_group()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
......@@ -287,7 +290,10 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""Async version of {meth}`_process_multimodal`."""
"""
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
......@@ -472,7 +478,7 @@ class InputPreprocessor:
Returns:
* {class}`SingletonInputs` instance
* [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
"""
parsed = parse_singleton_prompt(prompt)
......@@ -508,7 +514,10 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
"""Async version of {meth}`_prompt_to_llm_inputs`."""
"""
Async version of
[`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].
"""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
......@@ -644,7 +653,9 @@ class InputPreprocessor:
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
Process an input prompt into an {class}`EncoderDecoderInputs` instance.
Process an input prompt into an
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance.
There are two types of input prompts:
singleton prompts which carry only the
......@@ -670,7 +681,8 @@ class InputPreprocessor:
Returns:
* {class}`EncoderDecoderInputs` instance
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance
"""
encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs]
......@@ -710,7 +722,10 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> EncoderDecoderInputs:
"""Async version of {meth}`_process_encoder_decoder_prompt`."""
"""
Async version of
[`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].
"""
encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs]
......@@ -778,7 +793,8 @@ class InputPreprocessor:
) -> DecoderOnlyInputs:
"""
For decoder-only models:
Process an input prompt into an {class}`DecoderOnlyInputs` instance.
Process an input prompt into a
[`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.
Arguments:
......@@ -789,7 +805,7 @@ class InputPreprocessor:
Returns:
* {class}`DecoderOnlyInputs` instance
* [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
"""
prompt_comps = self._prompt_to_llm_inputs(
......@@ -812,7 +828,10 @@ class InputPreprocessor:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
) -> DecoderOnlyInputs:
"""Async version of {meth}`_process_decoder_only_prompt`."""
"""
Async version of
[`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
"""
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
......@@ -863,7 +882,10 @@ class InputPreprocessor:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
) -> ProcessorInputs:
"""Async version of {meth}`preprocess`."""
"""
Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
......
......@@ -38,7 +38,7 @@ class InputContext:
) -> _C:
"""
Get the HuggingFace configuration
({class}`transformers.PretrainedConfig`) of the model,
(`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
......@@ -79,7 +79,7 @@ class InputContext:
) -> _P:
"""
Get the HuggingFace processor
({class}`transformers.ProcessorMixin`) of the model,
(`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
......
......@@ -68,22 +68,22 @@ class _VllmLogger(Logger):
"""
Note:
This class is just to provide type information.
We actually patch the methods directly on the {class}`logging.Logger`
We actually patch the methods directly on the [`logging.Logger`][]
instance to avoid conflicting with other libraries such as
`intel_extension_for_pytorch.utils._logger`.
"""
def info_once(self, msg: str, *args: Hashable) -> None:
"""
As {meth}`info`, but subsequent calls with the same message
are silently dropped.
As [`info`][logging.Logger.info], but subsequent calls with
the same message are silently dropped.
"""
_print_info_once(self, msg, *args)
def warning_once(self, msg: str, *args: Hashable) -> None:
"""
As {meth}`warning`, but subsequent calls with the same message
are silently dropped.
As [`warning`][logging.Logger.warning], but subsequent calls with
the same message are silently dropped.
"""
_print_warning_once(self, msg, *args)
......
......@@ -18,7 +18,7 @@ logger = init_logger(__name__)
def prepare_object_to_dump(obj) -> str:
if isinstance(obj, str):
return "'{obj}'" # Double quotes
return f"'{obj}'" # Double quotes
elif isinstance(obj, dict):
dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \
for k, v in obj.items()})
......@@ -42,9 +42,9 @@ def prepare_object_to_dump(obj) -> str:
return obj.anon_repr()
elif hasattr(obj, '__dict__'):
items = obj.__dict__.items()
dict_str = ','.join([f'{str(k)}={prepare_object_to_dump(v)}' \
dict_str = ', '.join([f'{str(k)}={prepare_object_to_dump(v)}' \
for k, v in items])
return (f"{type(obj).__name__}({dict_str})")
return f"{type(obj).__name__}({dict_str})"
else:
# Hacky way to make sure we can serialize the object in JSON format
try:
......
......@@ -3,11 +3,11 @@
import copy
import math
import os
import re
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Union
import regex as re
import safetensors.torch
import torch
from torch import nn
......@@ -29,6 +29,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
get_supported_lora_modules,
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
......@@ -197,7 +198,7 @@ class LoRAModel(AdapterModel):
embedding_modules: Optional[dict[str, str]] = None,
embedding_padding_modules: Optional[list[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
Args:
......@@ -219,20 +220,11 @@ class LoRAModel(AdapterModel):
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin")
unexpected_modules: list[Union[list[str], str]]
if os.path.isfile(lora_tensor_path):
tensors: dict[str, torch.Tensor] = {}
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in the model it won’t error and model will be trained with A, B
# loraified. C won’t exist in the safetensor but it will exist in
# the target_modules of the adapter_config.json.
unexpected_modules = []
with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa
unexpected_modules: list[Union[list[str], str]] = []
def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa
module_name, _, _ = parse_fine_tuned_lora_name(
lora_module, weights_mapper)
part_name = module_name.split(".")[-1]
......@@ -243,9 +235,32 @@ class LoRAModel(AdapterModel):
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)
f" Please verify that the loaded LoRA module is correct")
if tensorizer_config_dict:
from tensorizer import TensorDeserializer
tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
"adapter_model.tensors")
tensorizer_args = tensorizer_config._construct_tensorizer_args()
tensors = TensorDeserializer(lora_tensor_path,
dtype=tensorizer_config.dtype,
**tensorizer_args.deserializer_params)
check_unexpected_modules(tensors)
elif os.path.isfile(lora_tensor_path):
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in the model it won’t error and model will be trained with A, B
# loraified. C won’t exist in the safetensor but it will exist in
# the target_modules of the adapter_config.json.
unexpected_modules = []
with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore
# Load tensors if there are only expected modules.
check_unexpected_modules(f)
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment