Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Union from typing import Union
import partial_json_parser import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import random_tool_call_id
...@@ -96,8 +96,9 @@ class JambaToolParser(ToolParser): ...@@ -96,8 +96,9 @@ class JambaToolParser(ToolParser):
function=FunctionCall( function=FunctionCall(
name=function_call["name"], name=function_call["name"],
# function call args are JSON but as a string # function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"]))) arguments=json.dumps(function_call["arguments"],
for function_call in raw_function_calls ensure_ascii=False),
)) for function_call in raw_function_calls
] ]
content = model_output[:model_output. content = model_output[:model_output.
...@@ -187,7 +188,7 @@ class JambaToolParser(ToolParser): ...@@ -187,7 +188,7 @@ class JambaToolParser(ToolParser):
diff: Union[str, None] = current_tool_call.get("arguments") diff: Union[str, None] = current_tool_call.get("arguments")
if diff: if diff:
diff = json.dumps(diff).replace( diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id], self.streamed_args_for_tool[self.current_tool_id],
"") "")
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
...@@ -248,7 +249,8 @@ class JambaToolParser(ToolParser): ...@@ -248,7 +249,8 @@ class JambaToolParser(ToolParser):
"mid-arguments") "mid-arguments")
delta = None delta = None
elif cur_arguments and not prev_arguments: elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments) cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("finding %s in %s", new_text, logger.debug("finding %s in %s", new_text,
cur_arguments_json) cur_arguments_json)
...@@ -267,8 +269,10 @@ class JambaToolParser(ToolParser): ...@@ -267,8 +269,10 @@ class JambaToolParser(ToolParser):
self.current_tool_id] += arguments_delta self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments: elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments) cur_args_json = json.dumps(cur_arguments,
prev_args_json = json.dumps(prev_arguments) ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
logger.debug("Searching for diff between \n%s\n%s", logger.debug("Searching for diff between \n%s\n%s",
cur_args_json, prev_args_json) cur_args_json, prev_args_json)
......
# SPDX-License-Identifier: Apache-2.0
import ast
import json
from collections.abc import Sequence
from typing import Any, Union
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
    """Signals that parsed model output does not have the expected tool-call
    shape (e.g. it is not a list of calls, or an argument is not a literal)."""
@ToolParserManager.register_module("llama4_pythonic")
class Llama4PythonicToolParser(ToolParser):
    """
    Tool-call parser for Llama 4 models, which produce tool calls in a
    pythonic style: a list of function calls with keyword arguments.

    Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
    """
    # TODO(mdepinet): Possible future improvements:
    #   1. Support text + tools separated by either <|python_tag|> or \n\n
    #   2. Support tools outside of a list (or separated by a semicolon).
    #      This depends on item 1 for consistent streaming.
    # Neither of these are necessary for e.g. ToolACE, but both would help make
    # Llama3.2 models more reliable.

    # Cheap pre-filter applied before the output is handed to ``ast.parse``.
    TOOL_CALL_REGEX = re.compile(
        r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
        re.DOTALL)

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

    # Rename for readability. This is NOT a tool id.
    @property
    def current_tool_index(self) -> int:
        """Alias for ``current_tool_id``: index of the tool being streamed."""
        return self.current_tool_id

    @current_tool_index.setter
    def current_tool_index(self, value: int) -> None:
        self.current_tool_id = value

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        Returns the tool calls with ``content=None`` when the output is a
        list of function calls; otherwise (including on any parsing error)
        the output is returned unchanged as plain content.
        """
        # Llama 4 sometimes wraps its output in <|python_start|> and
        # <|python_end|>; drop both when the start marker is present.
        start_marker = "<|python_start|>"
        if model_output.startswith(start_marker):
            model_output = model_output[len(start_marker):]
            model_output = model_output.replace("<|python_end|>", "")

        if not self.TOOL_CALL_REGEX.match(model_output):
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            module = ast.parse(model_output)
            parsed = getattr(module.body[0], "value", None)
            is_call_list = isinstance(parsed, ast.List) and all(
                isinstance(e, ast.Call) for e in parsed.elts)
            if not is_call_list:
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls")
            extracted_calls = [
                _handle_single_tool(e)  # type: ignore
                for e in parsed.elts
            ]
            return ExtractedToolCallInformation(tools_called=True,
                                                tool_calls=extracted_calls,
                                                content=None)
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # Treat as regular text
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Produce the streaming delta for the text generated so far.

        Text that does not begin with ``[`` or ``<|python_start|>`` is passed
        through as content. Otherwise the partial output is completed into
        valid Python, parsed, and the not-yet-sent part of each tool call is
        returned. ``None`` means "nothing to send for this chunk" (including
        when an error occurred while handling it).
        """
        if not current_text.startswith(("[", "<|python_start|>")):
            return DeltaMessage(content=delta_text)

        try:
            # remove <|python_start|> and <|python_end|>
            current_text = current_text.removeprefix("<|python_start|>")
            current_text = current_text.removesuffix("<|python_end|>")

            completed = _make_valid_python(current_text)
            if completed is None:
                return None
            valid_text, added_text = completed

            module = ast.parse(valid_text)
            parsed = getattr(module.body[0], "value", None)
            is_call_list = isinstance(parsed, ast.List) and all(
                isinstance(e, ast.Call) for e in parsed.elts)
            if not is_call_list:
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls")
            tool_calls = [
                _handle_single_tool(e)  # type: ignore
                for e in parsed.elts
            ]

            last_index = len(tool_calls) - 1
            # ")]" in the auto-added text means the last call was still open.
            last_call_open = ")]" in added_text
            tool_deltas = []
            for index, new_call in enumerate(tool_calls):
                if index < self.current_tool_index:
                    continue

                self.current_tool_index = index
                if len(self.streamed_args_for_tool) == index:
                    self.streamed_args_for_tool.append("")

                if index < last_index or not last_call_open:
                    # This call is complete; move on and withhold nothing.
                    self.current_tool_index += 1
                    withheld_suffix = ""
                else:
                    withheld_suffix = added_text[:-2]
                    if added_text[-2] == ")":
                        # Function call is incomplete. Withhold the closing
                        # bracket.
                        withheld_suffix += "}"
                # Strings get single quotes in the model-produced string.
                # JSON requires double quotes.
                withheld_suffix = withheld_suffix.replace("'", '"')

                delta = _compute_tool_delta(self.streamed_args_for_tool[index],
                                            new_call, index, withheld_suffix)
                if delta is None:
                    continue
                tool_deltas.append(delta)
                if (delta.function is not None
                        and delta.function.arguments is not None):
                    self.streamed_args_for_tool[
                        index] += delta.function.arguments

            if tool_deltas:
                # HACK: serving_chat.py inspects the internal state of tool
                # parsers when determining its final streaming delta,
                # automatically adding autocompleted JSON.
                # Seeding prev_tool_call_arr avoids that nonsense while
                # ensuring finish_reason is set to tool_calls when at least
                # one tool is called.
                if not self.prev_tool_call_arr:
                    self.prev_tool_call_arr = [{"arguments": {}}]
                return DeltaMessage(tool_calls=tool_deltas)

            if not added_text and self.current_tool_id > 0:
                # Return an empty DeltaMessage once the tool calls are all done
                # so that finish_reason gets set.
                return DeltaMessage(content='')

            return None
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
def _get_parameter_value(val: ast.expr) -> Any:
    """Convert the AST node of a tool-call argument into a plain Python value.

    Supported nodes:
      * constants (strings, numbers, booleans, ``None``),
      * signed numeric literals such as ``-1`` or ``+2.5`` (``ast.parse``
        represents these as a unary operator applied to a constant, not as
        a constant),
      * dicts whose keys are all constants (values are converted
        recursively),
      * lists (elements are converted recursively).

    Args:
        val: The AST expression of the argument value.

    Returns:
        The corresponding Python value.

    Raises:
        _UnexpectedAstError: If the node is none of the supported kinds.
    """
    if isinstance(val, ast.Constant):
        return val.value

    if isinstance(val, ast.UnaryOp) and isinstance(val.op,
                                                   (ast.USub, ast.UAdd)):
        operand = val.operand
        # Only plain ints/floats may carry a sign; bool is excluded because
        # ``-True`` is not a literal a tool call should contain.
        if (isinstance(operand, ast.Constant)
                and isinstance(operand.value, (int, float))
                and not isinstance(operand.value, bool)):
            if isinstance(val.op, ast.USub):
                return -operand.value
            return operand.value
        raise _UnexpectedAstError("Tool call arguments must be literals")

    if isinstance(val, ast.Dict):
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            raise _UnexpectedAstError(
                "Dict tool call arguments must have literal keys")
        return {
            k.value: _get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }

    if isinstance(val, ast.List):
        return [_get_parameter_value(v) for v in val.elts]

    raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Build a ``ToolCall`` from one parsed function-call AST node.

    Only keyword arguments are read; positional arguments on the call node
    are ignored. Argument values are converted with
    ``_get_parameter_value`` and serialized to a JSON string.

    Args:
        call: The AST node of a single call such as ``f(a=1, b='x')``.

    Returns:
        A ``ToolCall`` of type ``"function"`` carrying the function name and
        the JSON-encoded arguments.

    Raises:
        _UnexpectedAstError: If the callee is not a plain name, or if an
            argument value is rejected by ``_get_parameter_value``.
    """
    if not isinstance(call.func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")
    function_name = call.func.id
    arguments = {
        keyword.arg: _get_parameter_value(keyword.value)
        for keyword in call.keywords
    }
    return ToolCall(
        type="function",
        function=FunctionCall(
            name=function_name,
            # Keep non-ASCII characters as-is instead of \uXXXX escapes,
            # consistent with the other tool parsers.
            arguments=json.dumps(arguments, ensure_ascii=False)),
    )
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
    """Complete a partial pythonic tool-call string so it can be parsed.

    The text is scanned once while tracking open brackets and open string
    literals. Characters inside a string literal are never treated as
    brackets, and a quote only closes the string when it is preceded by an
    even number of backslashes (i.e. the quote itself is not escaped).

    Args:
        text: The model output generated so far.

    Returns:
        ``(completed_text, added_text)`` where ``added_text`` is the suffix
        of closing quotes/brackets appended to make the text balanced, or
        ``None`` when the text ends at a point that cannot be completed yet
        (trailing ``=`` or ``:``, or one of the incomplete-name heuristics
        below).

    Raises:
        _UnexpectedAstError: If a closing bracket does not match the most
            recently opened one.
    """
    quotes = {"'", '"'}
    opener_for = {"]": "[", ")": "(", "}": "{"}
    mismatch_message = {
        "]": "Mismatched square brackets",
        ")": "Mismatched parentheses",
        "}": "Mismatched curly braces",
    }
    closer_for = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}

    bracket_stack = []
    for index, char in enumerate(text):
        if bracket_stack and bracket_stack[-1] in quotes:
            # Inside a string literal: only the matching, unescaped quote is
            # significant. Brackets and the other quote kind are plain text.
            if char == bracket_stack[-1]:
                backslashes = 0
                while (index - backslashes - 1 >= 0
                       and text[index - backslashes - 1] == "\\"):
                    backslashes += 1
                if backslashes % 2 == 0:
                    bracket_stack.pop()
            continue

        if char in closer_for:
            # Opening bracket or opening quote.
            bracket_stack.append(char)
        elif char in opener_for:
            if not bracket_stack or bracket_stack.pop() != opener_for[char]:
                raise _UnexpectedAstError(mismatch_message[char])

    text = text.rstrip()
    if text.endswith("=") or text.endswith(":"):
        # Since we have no type information for this property/parameter value,
        # we can't fill in a valid value.
        return None
    # NOTE(review): both slices below keep the text *before* the last
    # opening bracket, although the variable names suggest the text after
    # it. Left unchanged; confirm which is intended.
    if bracket_stack and bracket_stack[-1] == "{":
        trailing_dict_text = text[:text.rfind("{")]
        num_keys = trailing_dict_text.count(":")
        num_values = trailing_dict_text.count(",")
        if num_keys <= num_values:
            return None  # Incomplete property name within parameter value
    if bracket_stack and bracket_stack[-1] == "(":
        trailing_params_text = text[:text.rfind("(")]
        num_full_param_names = trailing_params_text.count("=")
        num_full_param_values = trailing_params_text.count(",")
        if num_full_param_names <= num_full_param_values:
            return None  # Incomplete parameter name
    if text.endswith(","):
        text = text[:-1]
    if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
            "[") and not text.endswith(")"):
        return None  # Incomplete function name

    # Close everything that is still open, innermost first.
    added_text = "".join(closer_for[char] for char in reversed(bracket_stack))
    return text + added_text, added_text
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
                        index: int,
                        withheld_suffix: str) -> Union[DeltaToolCall, None]:
    """Work out what still has to be streamed for one tool call.

    Args:
        previously_sent_args: Argument text already streamed for this call.
        new_call: The tool call as parsed from the text generated so far.
        index: Position of the call in the list of tool calls.
        withheld_suffix: Trailing text of the serialized arguments that must
            not be sent yet; the serialized arguments are expected to end
            with it.

    Returns:
        The first delta for the call (id, type, name and arguments) when
        nothing has been sent so far; otherwise a delta holding only the
        new argument text, or ``None`` when there is no new text.
    """
    sendable_args = new_call.function.arguments
    if withheld_suffix:
        assert sendable_args.endswith(withheld_suffix)
        sendable_args = sendable_args[:-len(withheld_suffix)]

    if previously_sent_args:
        # Follow-up delta: only the part after what was already sent.
        unsent = sendable_args[len(previously_sent_args):]
        if not unsent:
            return None
        return DeltaToolCall(id=None,
                             index=index,
                             function=DeltaFunctionCall(arguments=unsent))

    # First delta for this call carries the id, type and function name.
    first_function = DeltaFunctionCall(
        name=new_call.function.name,
        arguments=sendable_args,
    )
    return DeltaToolCall(id=new_call.id,
                         type="function",
                         index=index,
                         function=first_function)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
from collections.abc import Sequence from collections.abc import Sequence
from json import JSONDecoder from json import JSONDecoder
from typing import Union from typing import Union
import partial_json_parser import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -88,7 +88,8 @@ class Llama3JsonToolParser(ToolParser): ...@@ -88,7 +88,8 @@ class Llama3JsonToolParser(ToolParser):
# function call args are JSON but as a string # function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"] \ arguments=json.dumps(raw_function_call["arguments"] \
if "arguments" in raw_function_call \ if "arguments" in raw_function_call \
else raw_function_call["parameters"]))) else raw_function_call["parameters"],
ensure_ascii=False)))
for raw_function_call in function_call_arr for raw_function_call in function_call_arr
] ]
...@@ -174,7 +175,8 @@ class Llama3JsonToolParser(ToolParser): ...@@ -174,7 +175,8 @@ class Llama3JsonToolParser(ToolParser):
if self.current_tool_id >= 0: if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments") cur_arguments = current_tool_call.get("arguments")
if cur_arguments: if cur_arguments:
cur_args_json = json.dumps(cur_arguments) cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len( sent = len(
self.streamed_args_for_tool[self.current_tool_id]) self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:] argument_diff = cur_args_json[sent:]
...@@ -226,7 +228,8 @@ class Llama3JsonToolParser(ToolParser): ...@@ -226,7 +228,8 @@ class Llama3JsonToolParser(ToolParser):
if cur_arguments: if cur_arguments:
sent = len( sent = len(
self.streamed_args_for_tool[self.current_tool_id]) self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments) cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[ prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments") self.current_tool_id].get("arguments")
...@@ -234,7 +237,8 @@ class Llama3JsonToolParser(ToolParser): ...@@ -234,7 +237,8 @@ class Llama3JsonToolParser(ToolParser):
if is_complete[self.current_tool_id]: if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:] argument_diff = cur_args_json[sent:]
elif prev_arguments: elif prev_arguments:
prev_args_json = json.dumps(prev_arguments) prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json: if cur_args_json != prev_args_json:
prefix = find_common_prefix( prefix = find_common_prefix(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
from collections.abc import Sequence from collections.abc import Sequence
from random import choices from random import choices
from string import ascii_letters, digits from string import ascii_letters, digits
from typing import Union from typing import Union
import partial_json_parser import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from pydantic import Field from pydantic import Field
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Optional from typing import Any, Optional
import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import random_tool_call_id
...@@ -79,10 +79,11 @@ class Phi4MiniJsonToolParser(ToolParser): ...@@ -79,10 +79,11 @@ class Phi4MiniJsonToolParser(ToolParser):
name=raw_function_call["name"], name=raw_function_call["name"],
# function call args are JSON but as a string # function call args are JSON but as a string
arguments=json.dumps( arguments=json.dumps(
raw_function_call["arguments"] if "arguments" in raw_function_call["arguments"]
raw_function_call else if "arguments" in raw_function_call else
raw_function_call["parameters"]))) raw_function_call["parameters"],
for raw_function_call in function_call_arr ensure_ascii=False),
)) for raw_function_call in function_call_arr
] ]
# get any content before the tool call # get any content before the tool call
......
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
import ast import ast
import json import json
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Union from typing import Any, Union
import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...@@ -200,9 +200,12 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: ...@@ -200,9 +200,12 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
arguments = {} arguments = {}
for keyword in call.keywords: for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value) arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(type="function", return ToolCall(
function=FunctionCall(name=function_name, type="function",
arguments=json.dumps(arguments))) function=FunctionCall(name=function_name,
arguments=json.dumps(arguments,
ensure_ascii=False)),
)
def _make_valid_python(text: str) -> Union[tuple[str, str], None]: def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
......
...@@ -13,6 +13,13 @@ from vllm.logger import init_logger ...@@ -13,6 +13,13 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
VLLM_SERVE_PARSER_EPILOG = (
"Tip: Use `vllm serve --help=<keyword>` to explore arguments from help.\n"
" - To view a argument group: --help=ModelConfig\n"
" - To view a single argument: --help=max-num-seqs\n"
" - To search by keyword: --help=max\n"
" - To list all groups: --help=listgroup")
async def listen_for_disconnect(request: Request) -> None: async def listen_for_disconnect(request: Request) -> None:
"""Returns if a disconnect message is received""" """Returns if a disconnect message is received"""
...@@ -158,3 +165,55 @@ def _validate_truncation_size( ...@@ -158,3 +165,55 @@ def _validate_truncation_size(
tokenization_kwargs["max_length"] = truncate_prompt_tokens tokenization_kwargs["max_length"] = truncate_prompt_tokens
return truncate_prompt_tokens return truncate_prompt_tokens
def show_filtered_argument_or_group_from_help(parser):
    """Handle ``--help=<keyword>`` by printing a filtered help and exiting.

    ``sys.argv`` is scanned for an argument starting with ``--help=``. For
    such an argument the keyword selects, in this order: the list of
    argument groups (``listgroup``), a single group whose title equals the
    keyword (case-insensitive), or every option whose name contains the
    keyword (case-insensitive). The process exits with status 0 after
    printing a match and with status 1 when nothing matches. If no
    ``--help=`` argument is present, the function returns without output.

    Args:
        parser: The argparse parser whose groups and actions are searched.
    """
    import sys

    for arg in sys.argv:
        if not arg.startswith('--help='):
            continue

        search_keyword = arg.split('=', 1)[1]
        wanted = search_keyword.lower()

        # List available groups
        if search_keyword == 'listgroup':
            print("\nAvailable argument groups:")
            for group in parser._action_groups:
                title = group.title
                if not title or title.startswith("positional arguments"):
                    continue
                print(f" - {title}")
                if group.description:
                    print(" " + group.description.strip())
                print()
            sys.exit(0)

        # For group search
        formatter = parser._get_formatter()
        for group in parser._action_groups:
            if not group.title or group.title.lower() != wanted:
                continue
            formatter.start_section(group.title)
            formatter.add_text(group.description)
            formatter.add_arguments(group._group_actions)
            formatter.end_section()
            print(formatter.format_help())
            sys.exit(0)

        # For single arg: search the option names of every action
        matched_actions = [
            action for group in parser._action_groups
            for action in group._group_actions
            if any(wanted in opt.lower() for opt in action.option_strings)
        ]
        if matched_actions:
            print(f"\nParameters matching '{search_keyword}':\n")
            formatter = parser._get_formatter()
            formatter.add_arguments(matched_actions)
            print(formatter.format_help())
            sys.exit(0)

        print(f"\nNo group or parameter matching '{search_keyword}'")
        print("Tip: use `--help=listgroup` to view all groups.")
        sys.exit(1)
...@@ -117,6 +117,7 @@ if TYPE_CHECKING: ...@@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive" VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
def get_default_cache_root(): def get_default_cache_root():
...@@ -163,7 +164,7 @@ def get_vllm_port() -> Optional[int]: ...@@ -163,7 +164,7 @@ def get_vllm_port() -> Optional[int]:
raise ValueError( raise ValueError(
f"VLLM_PORT '{port}' appears to be a URI. " f"VLLM_PORT '{port}' appears to be a URI. "
"This may be caused by a Kubernetes service discovery issue" "This may be caused by a Kubernetes service discovery issue"
"check the warning in: https://docs.vllm.ai/en/stable/serving/env_vars.html" "check the warning in: https://docs.vllm.ai/en/stable/usage/env_vars.html"
) )
except Exception: except Exception:
pass pass
...@@ -175,7 +176,7 @@ def get_vllm_port() -> Optional[int]: ...@@ -175,7 +176,7 @@ def get_vllm_port() -> Optional[int]:
# The begin-* and end* here are used by the documentation generator # The begin-* and end* here are used by the documentation generator
# to extract the used env vars. # to extract the used env vars.
# begin-env-vars-definition # --8<-- [start:env-vars-definition]
environment_variables: dict[str, Callable[[], Any]] = { environment_variables: dict[str, Callable[[], Any]] = {
...@@ -809,11 +810,21 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -809,11 +810,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
# all2all backend for vllm's expert parallel communication # all2all backend for vllm's expert parallel communication
# Available options:
# - "naive": naive all2all implementation using all-reduce
# - "pplx": use pplx kernels
"VLLM_ALL2ALL_BACKEND": "VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
# This is used to prevent the kernel from running out of memory.
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
} }
# end-env-vars-definition # --8<-- [end:env-vars-definition]
def __getattr__(name: str): def __getattr__(name: str):
......
...@@ -74,7 +74,7 @@ class ExecutorBase(ABC): ...@@ -74,7 +74,7 @@ class ExecutorBase(ABC):
`self` argument, in addition to the arguments passed in `args` `self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object. and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a timeout: Maximum time in seconds to wait for execution. Raises a
{exc}`TimeoutError` on timeout. `None` means wait indefinitely. [`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method. args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method. kwargs: Keyword arguments to pass to the worker method.
......
...@@ -528,12 +528,12 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -528,12 +528,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
ray.get(parallel_worker_tasks) ray.get(parallel_worker_tasks)
def _check_ray_cgraph_installation(self): def _check_ray_cgraph_installation(self):
import pkg_resources import importlib.metadata
from packaging import version from packaging import version
required_version = version.parse("2.43.0") required_version = version.parse("2.43.0")
current_version = version.parse( current_version = version.parse(importlib.metadata.version("ray"))
pkg_resources.get_distribution("ray").version)
if current_version < required_version: if current_version < required_version:
raise ValueError(f"Ray version {required_version} is " raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}") f"required, but found {current_version}")
......
...@@ -87,9 +87,8 @@ try: ...@@ -87,9 +87,8 @@ try:
# TODO(swang): This is needed right now because Ray Compiled Graph # TODO(swang): This is needed right now because Ray Compiled Graph
# executes on a background thread, so we need to reset torch's # executes on a background thread, so we need to reset torch's
# current device. # current device.
import torch
if not self.compiled_dag_cuda_device_set: if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device) current_platform.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True self.compiled_dag_cuda_device_set = True
output = self.worker._execute_model_spmd(execute_model_req, output = self.worker._execute_model_spmd(execute_model_req,
...@@ -113,8 +112,7 @@ try: ...@@ -113,8 +112,7 @@ try:
# Not needed # Not needed
pass pass
else: else:
import torch current_platform.set_device(self.worker.device)
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True self.compiled_dag_cuda_device_set = True
......
...@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any, ...@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
# we use synchronous scheduling right now, # we use synchronous scheduling right now,
# adding a sync point here should not affect # adding a sync point here should not affect
# scheduling of the next batch # scheduling of the next batch
torch.cuda.synchronize() from vllm.platforms import current_platform
synchronize = current_platform.synchronize
if synchronize is not None:
synchronize()
now = time.perf_counter() now = time.perf_counter()
# time measurement is in milliseconds # time measurement is in milliseconds
batchsize_forward_time[batchsize].append( batchsize_forward_time[batchsize].append(
......
...@@ -10,8 +10,9 @@ from .registry import (DummyData, InputContext, InputProcessingContext, ...@@ -10,8 +10,9 @@ from .registry import (DummyData, InputContext, InputProcessingContext,
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
""" """
The global {class}`~InputRegistry` which is used by {class}`~vllm.LLMEngine` The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used
to dispatch data processing according to the target model. by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the
target model.
""" """
__all__ = [ __all__ = [
......
...@@ -80,22 +80,24 @@ SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] ...@@ -80,22 +80,24 @@ SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
""" """
Set of possible schemas for a single prompt: Set of possible schemas for a single prompt:
- A text prompt ({class}`str` or {class}`TextPrompt`) - A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ({class}`TokensPrompt`) - A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ({class}`EmbedsPrompt`) - An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
Note that "singleton" is as opposed to a data structure Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder the user desires to express both the encoder & decoder
prompts explicitly, i.e. {class}`ExplicitEncoderDecoderPrompt` prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type {class}`SingletonPrompt` may be employed A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
as (1) input to a decoder-only model, (2) input to employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating (3) as a member of a larger data structure encapsulating
more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt` more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
""" """
...@@ -126,18 +128,20 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): ...@@ -126,18 +128,20 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
comprising an explicit encoder prompt and a decoder prompt. comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively, may be formatted The encoder and decoder prompts, respectively, may be formatted
according to any of the {class}`SingletonPrompt` schemas, according to any of the
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
and are not required to have the same schema. and are not required to have the same schema.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder. prompts, since they are agnostic to the encoder/decoder.
Note that an {class}`ExplicitEncoderDecoderPrompt` may not Note that an
be used as an input to a decoder-only model, [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
may not be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt` and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be fields of this data structure themselves must be
{class}`SingletonPrompt` instances. [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
""" """
encoder_prompt: _T1_co encoder_prompt: _T1_co
...@@ -152,11 +156,11 @@ PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] ...@@ -152,11 +156,11 @@ PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
Set of possible schemas for an LLM input, including Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types: both decoder-only and encoder/decoder input types:
- A text prompt ({class}`str` or {class}`TextPrompt`) - A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ({class}`TokensPrompt`) - A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ({class}`EmbedsPrompt`) - An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt - A single data structure containing both an encoder and a decoder prompt
({class}`ExplicitEncoderDecoderPrompt`) ([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
""" """
...@@ -189,7 +193,8 @@ def token_inputs( ...@@ -189,7 +193,8 @@ def token_inputs(
prompt: Optional[str] = None, prompt: Optional[str] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
) -> TokenInputs: ) -> TokenInputs:
"""Construct {class}`TokenInputs` from optional values.""" """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
values."""
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
if prompt is not None: if prompt is not None:
...@@ -221,7 +226,8 @@ def embeds_inputs( ...@@ -221,7 +226,8 @@ def embeds_inputs(
prompt_embeds: torch.Tensor, prompt_embeds: torch.Tensor,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
) -> EmbedsInputs: ) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values.""" """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
values."""
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds) inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
if cache_salt is not None: if cache_salt is not None:
...@@ -232,7 +238,7 @@ def embeds_inputs( ...@@ -232,7 +238,7 @@ def embeds_inputs(
DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
""" """
The inputs in {class}`~vllm.LLMEngine` before they are The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor. passed to the model executor.
This specifies the data required for decoder-only models. This specifies the data required for decoder-only models.
""" """
...@@ -240,11 +246,12 @@ This specifies the data required for decoder-only models. ...@@ -240,11 +246,12 @@ This specifies the data required for decoder-only models.
class EncoderDecoderInputs(TypedDict): class EncoderDecoderInputs(TypedDict):
""" """
The inputs in {class}`~vllm.LLMEngine` before they are The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
passed to the model executor. are passed to the model executor.
This specifies the required data for encoder-decoder models. This specifies the required data for encoder-decoder models.
""" """
encoder: Union[TokenInputs, "MultiModalInputs"] encoder: Union[TokenInputs, "MultiModalInputs"]
"""The inputs for the encoder portion.""" """The inputs for the encoder portion."""
...@@ -254,13 +261,13 @@ class EncoderDecoderInputs(TypedDict): ...@@ -254,13 +261,13 @@ class EncoderDecoderInputs(TypedDict):
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
""" """
A processed {class}`SingletonPrompt` which can be passed to A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
{class}`vllm.sequence.Sequence`. passed to [`vllm.sequence.Sequence`][].
""" """
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
""" """
The inputs to {data}`vllm.inputs.InputProcessor`. The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
""" """
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) _T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
...@@ -277,7 +284,8 @@ def build_explicit_enc_dec_prompt( ...@@ -277,7 +284,8 @@ def build_explicit_enc_dec_prompt(
return ExplicitEncoderDecoderPrompt( return ExplicitEncoderDecoderPrompt(
encoder_prompt=encoder_prompt, encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt, decoder_prompt=decoder_prompt,
mm_processor_kwargs=mm_processor_kwargs) mm_processor_kwargs=mm_processor_kwargs,
)
def zip_enc_dec_prompts( def zip_enc_dec_prompts(
...@@ -288,7 +296,8 @@ def zip_enc_dec_prompts( ...@@ -288,7 +296,8 @@ def zip_enc_dec_prompts(
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]: ) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
""" """
Zip encoder and decoder prompts together into a list of Zip encoder and decoder prompts together into a list of
{class}`ExplicitEncoderDecoderPrompt` instances. [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
instances.
``mm_processor_kwargs`` may also be provided; if a dict is passed, the same ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
dictionary will be used for every encoder/decoder prompt. If an iterable is dictionary will be used for every encoder/decoder prompt. If an iterable is
...@@ -299,10 +308,11 @@ def zip_enc_dec_prompts( ...@@ -299,10 +308,11 @@ def zip_enc_dec_prompts(
if isinstance(mm_processor_kwargs, dict): if isinstance(mm_processor_kwargs, dict):
return [ return [
build_explicit_enc_dec_prompt( build_explicit_enc_dec_prompt(
encoder_prompt, decoder_prompt, encoder_prompt,
cast(dict[str, Any], mm_processor_kwargs)) decoder_prompt,
for (encoder_prompt, cast(dict[str, Any], mm_processor_kwargs),
decoder_prompt) in zip(enc_prompts, dec_prompts) ) for (encoder_prompt,
decoder_prompt) in zip(enc_prompts, dec_prompts)
] ]
return [ return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
......
...@@ -23,13 +23,13 @@ class ParsedTokens(TypedDict): ...@@ -23,13 +23,13 @@ class ParsedTokens(TypedDict):
@overload @overload
def parse_and_batch_prompt( def parse_and_batch_prompt(
prompt: Union[str, list[str]]) -> Sequence[ParsedText]: prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
... ...
@overload @overload
def parse_and_batch_prompt( def parse_and_batch_prompt(
prompt: Union[list[int], list[list[int]]]) -> Sequence[ParsedTokens]: prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
... ...
...@@ -86,7 +86,7 @@ class ParsedTokensPrompt(TypedDict): ...@@ -86,7 +86,7 @@ class ParsedTokensPrompt(TypedDict):
class ParsedEmbedsPrompt(TypedDict): class ParsedEmbedsPrompt(TypedDict):
type: Literal['embeds'] type: Literal["embeds"]
content: EmbedsPrompt content: EmbedsPrompt
...@@ -133,7 +133,7 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: ...@@ -133,7 +133,7 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
def is_explicit_encoder_decoder_prompt( def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: prompt: PromptType, ) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
......
...@@ -67,11 +67,11 @@ class InputPreprocessor: ...@@ -67,11 +67,11 @@ class InputPreprocessor:
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def get_decoder_start_token_id(self) -> Optional[int]: def get_decoder_start_token_id(self) -> Optional[int]:
''' """
Obtain the decoder start token id employed by an encoder/decoder Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the model. Returns None for non-encoder/decoder models or if the
model config is unavailable. model config is unavailable.
''' """
if not self.model_config.is_encoder_decoder: if not self.model_config.is_encoder_decoder:
logger.warning_once( logger.warning_once(
...@@ -79,14 +79,14 @@ class InputPreprocessor: ...@@ -79,14 +79,14 @@ class InputPreprocessor:
"this is not an encoder/decoder model.") "this is not an encoder/decoder model.")
return None return None
if (self.model_config is None or self.model_config.hf_config is None): if self.model_config is None or self.model_config.hf_config is None:
logger.warning_once( logger.warning_once(
"Using None for decoder start token id because " "Using None for decoder start token id because "
"model config is not available.") "model config is not available.")
return None return None
dec_start_token_id = getattr(self.model_config.hf_config, dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None) "decoder_start_token_id", None)
if dec_start_token_id is None: if dec_start_token_id is None:
logger.warning_once( logger.warning_once(
"Falling back on <BOS> for decoder start token " "Falling back on <BOS> for decoder start token "
...@@ -97,7 +97,7 @@ class InputPreprocessor: ...@@ -97,7 +97,7 @@ class InputPreprocessor:
return dec_start_token_id return dec_start_token_id
def _get_default_enc_dec_decoder_prompt(self) -> list[int]: def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
''' """
Specifically for encoder/decoder models: Specifically for encoder/decoder models:
generate a default decoder prompt for when generate a default decoder prompt for when
the user specifies only the encoder prompt. the user specifies only the encoder prompt.
...@@ -126,7 +126,7 @@ class InputPreprocessor: ...@@ -126,7 +126,7 @@ class InputPreprocessor:
Returns: Returns:
* prompt_token_ids * prompt_token_ids
''' """
bos_token_id = self.get_bos_token_id() bos_token_id = self.get_bos_token_id()
assert bos_token_id is not None assert bos_token_id is not None
...@@ -224,7 +224,10 @@ class InputPreprocessor: ...@@ -224,7 +224,10 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]: ) -> list[int]:
"""Async version of {meth}`_tokenize_prompt`.""" """
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer = self.get_tokenizer_group() tokenizer = self.get_tokenizer_group()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
...@@ -287,7 +290,10 @@ class InputPreprocessor: ...@@ -287,7 +290,10 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> MultiModalInputs: ) -> MultiModalInputs:
"""Async version of {meth}`_process_multimodal`.""" """
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request) tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config, mm_processor = self.mm_registry.create_processor(self.model_config,
...@@ -472,7 +478,7 @@ class InputPreprocessor: ...@@ -472,7 +478,7 @@ class InputPreprocessor:
Returns: Returns:
* {class}`SingletonInputs` instance * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
""" """
parsed = parse_singleton_prompt(prompt) parsed = parse_singleton_prompt(prompt)
...@@ -508,7 +514,10 @@ class InputPreprocessor: ...@@ -508,7 +514,10 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> SingletonInputs: ) -> SingletonInputs:
"""Async version of {meth}`_prompt_to_llm_inputs`.""" """
Async version of
[`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].
"""
parsed = parse_singleton_prompt(prompt) parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds": if parsed["type"] == "embeds":
...@@ -644,7 +653,9 @@ class InputPreprocessor: ...@@ -644,7 +653,9 @@ class InputPreprocessor:
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
For encoder/decoder models only: For encoder/decoder models only:
Process an input prompt into an {class}`EncoderDecoderInputs` instance. Process an input prompt into an
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance.
There are two types of input prompts: There are two types of input prompts:
singleton prompts which carry only the singleton prompts which carry only the
...@@ -670,7 +681,8 @@ class InputPreprocessor: ...@@ -670,7 +681,8 @@ class InputPreprocessor:
Returns: Returns:
* {class}`EncoderDecoderInputs` instance * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance
""" """
encoder_inputs: SingletonInputs encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs] decoder_inputs: Optional[SingletonInputs]
...@@ -710,7 +722,10 @@ class InputPreprocessor: ...@@ -710,7 +722,10 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
"""Async version of {meth}`_process_encoder_decoder_prompt`.""" """
Async version of
[`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].
"""
encoder_inputs: SingletonInputs encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs] decoder_inputs: Optional[SingletonInputs]
...@@ -778,7 +793,8 @@ class InputPreprocessor: ...@@ -778,7 +793,8 @@ class InputPreprocessor:
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
For decoder-only models: For decoder-only models:
Process an input prompt into an {class}`DecoderOnlyInputs` instance. Process an input prompt into a
[`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.
Arguments: Arguments:
...@@ -789,7 +805,7 @@ class InputPreprocessor: ...@@ -789,7 +805,7 @@ class InputPreprocessor:
Returns: Returns:
* {class}`DecoderOnlyInputs` instance * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
""" """
prompt_comps = self._prompt_to_llm_inputs( prompt_comps = self._prompt_to_llm_inputs(
...@@ -812,7 +828,10 @@ class InputPreprocessor: ...@@ -812,7 +828,10 @@ class InputPreprocessor:
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
"""Async version of {meth}`_process_decoder_only_prompt`.""" """
Async version of
[`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
"""
prompt_comps = await self._prompt_to_llm_inputs_async( prompt_comps = await self._prompt_to_llm_inputs_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
...@@ -863,7 +882,10 @@ class InputPreprocessor: ...@@ -863,7 +882,10 @@ class InputPreprocessor:
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Async version of {meth}`preprocess`.""" """
Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, ( assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ", "Multimodal hashes for encoder-decoder models should not be ",
......
...@@ -38,7 +38,7 @@ class InputContext: ...@@ -38,7 +38,7 @@ class InputContext:
) -> _C: ) -> _C:
""" """
Get the HuggingFace configuration Get the HuggingFace configuration
({class}`transformers.PretrainedConfig`) of the model, (`transformers.PretrainedConfig`) of the model,
additionally checking its type. additionally checking its type.
Raises: Raises:
...@@ -79,7 +79,7 @@ class InputContext: ...@@ -79,7 +79,7 @@ class InputContext:
) -> _P: ) -> _P:
""" """
Get the HuggingFace processor Get the HuggingFace processor
({class}`transformers.ProcessorMixin`) of the model, (`transformers.ProcessorMixin`) of the model,
additionally checking its type. additionally checking its type.
Raises: Raises:
......
...@@ -68,22 +68,22 @@ class _VllmLogger(Logger): ...@@ -68,22 +68,22 @@ class _VllmLogger(Logger):
""" """
Note: Note:
This class is just to provide type information. This class is just to provide type information.
We actually patch the methods directly on the {class}`logging.Logger` We actually patch the methods directly on the [`logging.Logger`][]
instance to avoid conflicting with other libraries such as instance to avoid conflicting with other libraries such as
`intel_extension_for_pytorch.utils._logger`. `intel_extension_for_pytorch.utils._logger`.
""" """
def info_once(self, msg: str, *args: Hashable) -> None: def info_once(self, msg: str, *args: Hashable) -> None:
""" """
As {meth}`info`, but subsequent calls with the same message As [`info`][logging.Logger.info], but subsequent calls with
are silently dropped. the same message are silently dropped.
""" """
_print_info_once(self, msg, *args) _print_info_once(self, msg, *args)
def warning_once(self, msg: str, *args: Hashable) -> None: def warning_once(self, msg: str, *args: Hashable) -> None:
""" """
As {meth}`warning`, but subsequent calls with the same message As [`warning`][logging.Logger.warning], but subsequent calls with
are silently dropped. the same message are silently dropped.
""" """
_print_warning_once(self, msg, *args) _print_warning_once(self, msg, *args)
......
...@@ -18,7 +18,7 @@ logger = init_logger(__name__) ...@@ -18,7 +18,7 @@ logger = init_logger(__name__)
def prepare_object_to_dump(obj) -> str: def prepare_object_to_dump(obj) -> str:
if isinstance(obj, str): if isinstance(obj, str):
return "'{obj}'" # Double quotes return f"'{obj}'" # Double quotes
elif isinstance(obj, dict): elif isinstance(obj, dict):
dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \ dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \
for k, v in obj.items()}) for k, v in obj.items()})
...@@ -42,9 +42,9 @@ def prepare_object_to_dump(obj) -> str: ...@@ -42,9 +42,9 @@ def prepare_object_to_dump(obj) -> str:
return obj.anon_repr() return obj.anon_repr()
elif hasattr(obj, '__dict__'): elif hasattr(obj, '__dict__'):
items = obj.__dict__.items() items = obj.__dict__.items()
dict_str = ','.join([f'{str(k)}={prepare_object_to_dump(v)}' \ dict_str = ', '.join([f'{str(k)}={prepare_object_to_dump(v)}' \
for k, v in items]) for k, v in items])
return (f"{type(obj).__name__}({dict_str})") return f"{type(obj).__name__}({dict_str})"
else: else:
# Hacky way to make sure we can serialize the object in JSON format # Hacky way to make sure we can serialize the object in JSON format
try: try:
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
import copy import copy
import math import math
import os import os
import re
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import regex as re
import safetensors.torch import safetensors.torch
import torch import torch
from torch import nn from torch import nn
...@@ -29,6 +29,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, ...@@ -29,6 +29,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
get_supported_lora_modules, get_supported_lora_modules,
is_regex_target_modules, is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
...@@ -185,19 +186,19 @@ class LoRAModel(AdapterModel): ...@@ -185,19 +186,19 @@ class LoRAModel(AdapterModel):
@classmethod @classmethod
def from_local_checkpoint( def from_local_checkpoint(
cls, cls,
lora_dir: str, lora_dir: str,
expected_lora_modules: list[str], expected_lora_modules: list[str],
peft_helper: PEFTHelper, peft_helper: PEFTHelper,
*, *,
lora_model_id: Optional[int] = None, lora_model_id: Optional[int] = None,
device: str = "cuda", device: str = "cuda",
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None, target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[dict[str, str]] = None, embedding_modules: Optional[dict[str, str]] = None,
embedding_padding_modules: Optional[list[str]] = None, embedding_padding_modules: Optional[list[str]] = None,
weights_mapper: Optional[WeightsMapper] = None, weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel": tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint. """Create a LoRAModel from a local checkpoint.
Args: Args:
...@@ -219,10 +220,36 @@ class LoRAModel(AdapterModel): ...@@ -219,10 +220,36 @@ class LoRAModel(AdapterModel):
lora_dir, "new_embeddings.safetensors") lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir, new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin") "new_embeddings.bin")
tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[Union[list[str], str]] = []
def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa
module_name, _, _ = parse_fine_tuned_lora_name(
lora_module, weights_mapper)
part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
unexpected_modules: list[Union[list[str], str]] if tensorizer_config_dict:
if os.path.isfile(lora_tensor_path): from tensorizer import TensorDeserializer
tensors: dict[str, torch.Tensor] = {}
tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
"adapter_model.tensors")
tensorizer_args = tensorizer_config._construct_tensorizer_args()
tensors = TensorDeserializer(lora_tensor_path,
dtype=tensorizer_config.dtype,
**tensorizer_args.deserializer_params)
check_unexpected_modules(tensors)
elif os.path.isfile(lora_tensor_path):
# Find unexpected modules. # Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules. # Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist # in peft if you have target_modules A, B, C and C does not exist
...@@ -232,20 +259,8 @@ class LoRAModel(AdapterModel): ...@@ -232,20 +259,8 @@ class LoRAModel(AdapterModel):
unexpected_modules = [] unexpected_modules = []
with safetensors.safe_open(lora_tensor_path, with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa
module_name, _, _ = parse_fine_tuned_lora_name(
lora_module, weights_mapper)
part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)
# Load tensors if there are only expected modules. # Load tensors if there are only expected modules.
check_unexpected_modules(f)
for module in f.keys(): # noqa for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module) tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path): elif os.path.isfile(lora_bin_file_path):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment