Commit d2b52805 authored by zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
......@@ -8,7 +8,7 @@ from typing import Any, Optional
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
......@@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser):
tool_calls: list[ToolCall] = [
ToolCall(
id=random_tool_call_id(),
id=make_tool_call_id(),
type="function",
function=FunctionCall(
name=raw_function_call["name"],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
import uuid
from collections.abc import Sequence
......@@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module(["qwen3_coder"])
@ToolParserManager.register_module("qwen3_coder")
class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
......@@ -30,6 +30,8 @@ class Qwen3CoderToolParser(ToolParser):
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
# Override base class type - we use string IDs for tool calls
self.current_tool_id: Optional[str] = None # type: ignore
self.streamed_args_for_tool: list[str] = []
# Sentinel tokens for streaming mode
......@@ -42,20 +44,6 @@ class Qwen3CoderToolParser(ToolParser):
self.is_tool_call_started: bool = False
self.failed_count: int = 0
# Streaming state variables
self.current_tool_index: int = 0
self.header_sent: bool = False
self.current_tool_string_id: Optional[str] = None
self.current_function_name: Optional[str] = None
self.current_param_name: Optional[str] = None
self.current_param_value: str = ""
self.param_count: int = 0
self.in_param: bool = False
self.in_function: bool = False
self.accumulated_text: str = ""
self.json_started: bool = False
self.json_closed: bool = False
# Enhanced streaming state - reset for each new message
self._reset_streaming_state()
......@@ -67,7 +55,8 @@ class Qwen3CoderToolParser(ToolParser):
self.tool_call_function_regex = re.compile(
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
......@@ -84,7 +73,7 @@ class Qwen3CoderToolParser(ToolParser):
"Qwen3 XML Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
logger.debug("vLLM Successfully import tool parser %s !",
logger.info("vLLM Successfully import tool parser %s !",
self.__class__.__name__)
def _generate_tool_call_id(self) -> str:
......@@ -96,7 +85,7 @@ class Qwen3CoderToolParser(ToolParser):
self.current_tool_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_string_id = None
self.current_tool_id = None
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
......@@ -106,22 +95,21 @@ class Qwen3CoderToolParser(ToolParser):
self.accumulated_text = ""
self.json_started = False
self.json_closed = False
def _parse_xml_function_call(
self, function_call_str: str,
tools: Optional[list[ChatCompletionToolsParam]]
) -> Optional[ToolCall]:
def get_arguments_config(func_name: str) -> dict:
# Store accumulated parameters for type conversion
self.accumulated_params = {}
self.streaming_request = None
def _get_arguments_config(
self, func_name: str,
tools: Optional[list[ChatCompletionToolsParam]]) -> dict:
"""Extract argument configuration for a function."""
if tools is None:
return {}
for config in tools:
if not hasattr(config, "type") or not (
hasattr(config, "function")
and hasattr(config.function, "name")):
if not hasattr(config, "type") or not (hasattr(
config, "function") and hasattr(config.function, "name")):
continue
if (config.type == "function"
and config.function.name == func_name):
if config.type == "function" and config.function.name == func_name:
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
......@@ -135,14 +123,13 @@ class Qwen3CoderToolParser(ToolParser):
func_name)
return {}
def convert_param_value(param_value: str, param_name: str,
def _convert_param_value(self, param_value: str, param_name: str,
param_config: dict, func_name: str) -> Any:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if param_value.lower() == "null":
return None
converted_value: Any
if param_name not in param_config:
if param_config != {}:
logger.warning(
......@@ -151,38 +138,31 @@ class Qwen3CoderToolParser(ToolParser):
"string value.", param_name, func_name)
return param_value
if (isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]):
param_type = str(
param_config[param_name]["type"]).strip().lower()
if isinstance(param_config[param_name],
dict) and "type" in param_config[param_name]:
param_type = str(param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in [
"string", "str", "text", "varchar", "char", "enum"
]:
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (param_type.startswith("int") or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")):
elif param_type.startswith("int") or param_type.startswith(
"uint") or param_type.startswith(
"long") or param_type.startswith(
"short") or param_type.startswith("unsigned"):
try:
converted_value = int(param_value)
return converted_value
except ValueError:
return int(param_value)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not an "
"integer in tool '%s', degenerating to string.",
param_value, param_name, func_name)
return param_value
elif (param_type.startswith("num")
or param_type.startswith("float")):
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value = float(param_value)
converted_value = (float_param_value if float_param_value -
int(float_param_value) != 0 else
int(float_param_value))
return converted_value
except ValueError:
return float_param_value if float_param_value - int(
float_param_value) != 0 else int(float_param_value)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', degenerating to string.", param_value,
......@@ -192,36 +172,45 @@ class Qwen3CoderToolParser(ToolParser):
param_value = param_value.lower()
if param_value not in ["true", "false"]:
logger.warning(
"Parsed value '%s' of parameter '%s' is not a "
"boolean (`true` of `false`) in tool '%s', "
"degenerating to false.", param_value, param_name,
func_name)
"Parsed value '%s' of parameter '%s' is not a boolean "
"(`true` or `false`) in tool '%s', degenerating to "
"false.", param_value, param_name, func_name)
return param_value == "true"
else:
if param_type == "object" or param_type.startswith("dict"):
if param_type in ["object", "array", "arr"
] or param_type.startswith(
"dict") or param_type.startswith("list"):
try:
converted_value = json.loads(param_value)
return converted_value
except json.JSONDecodeError:
param_value = json.loads(param_value)
return param_value
except (json.JSONDecodeError, TypeError, ValueError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not a "
"valid JSON object in tool '%s', will try other "
"methods to parse it.", param_value, param_name,
"Parsed value '%s' of parameter '%s' cannot be "
"parsed with json.loads in tool '%s', will try "
"other methods to parse it.", param_value, param_name,
func_name)
try:
param_value = ast.literal_eval(param_value) # safer
except (ValueError, SyntaxError, TypeError):
logger.warning(
"Parameter '%s' has unknown type '%s'. "
"The value will be treated as a string.", param_name,
param_type)
"Parsed value '%s' of parameter '%s' cannot be "
"converted via Python `ast.literal_eval()` in tool "
"'%s', degenerating to string.", param_value, param_name,
func_name)
return param_value
def _parse_xml_function_call(
self, function_call_str: str,
tools: Optional[list[ChatCompletionToolsParam]]
) -> Optional[ToolCall]:
# Extract function name
end_index = function_call_str.index(">")
function_name = function_call_str[:end_index]
param_config = get_arguments_config(function_name)
param_config = self._get_arguments_config(function_name, tools)
parameters = function_call_str[end_index + 1:]
param_dict = {}
for match in self.tool_call_parameter_regex.findall(parameters):
match_text = match[0] if match[0] else match[1]
for match_text in self.tool_call_parameter_regex.findall(parameters):
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1:])
......@@ -231,7 +220,7 @@ class Qwen3CoderToolParser(ToolParser):
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = convert_param_value(
param_dict[param_name] = self._convert_param_value(
param_value, param_name, param_config, function_name)
return ToolCall(
type="function",
......@@ -284,8 +273,7 @@ class Qwen3CoderToolParser(ToolParser):
for function_call_str in function_calls
]
# Populate prev_tool_call_arr for serving layer to set
# finish_reason
# Populate prev_tool_call_arr for serving layer to set finish_reason
self.prev_tool_call_arr.clear() # Clear previous calls
for tool_call in tool_calls:
if tool_call:
......@@ -298,8 +286,8 @@ class Qwen3CoderToolParser(ToolParser):
# Extract content before tool calls
content_index = model_output.find(self.tool_call_start_token)
content_index = (content_index if content_index >= 0 else
model_output.find(self.tool_call_prefix))
idx = model_output.find(self.tool_call_prefix)
content_index = content_index if content_index >= 0 else idx
content = model_output[:content_index] # .rstrip()
return ExtractedToolCallInformation(
......@@ -324,13 +312,16 @@ class Qwen3CoderToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# If no delta text, return None unless it's an EOS token after tool
# calls
# Store request for type conversion
if not previous_text:
self._reset_streaming_state()
self.streaming_request = request
# If no delta text, return None unless it's an EOS token after tools
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
# We check for tool calls in the text even if is_tool_call_started
# is False because it might have been reset after processing all
# tools
# Check for tool calls in text even if is_tool_call_started
# is False (might have been reset after processing all tools)
if (delta_token_ids
and self.tool_call_end_token_id not in delta_token_ids):
# Count complete tool calls
......@@ -339,24 +330,19 @@ class Qwen3CoderToolParser(ToolParser):
# If we have completed tool calls and populated
# prev_tool_call_arr
if (complete_calls > 0 and len(self.prev_tool_call_arr) > 0):
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = (
current_text.count(self.tool_call_start_token) -
current_text.count(self.tool_call_end_token))
open_calls = current_text.count(
self.tool_call_start_token) - current_text.count(
self.tool_call_end_token)
if open_calls == 0:
# Return empty delta message to allow finish_reason
# processing
# Return empty delta for finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Check if this is the first call (reset state if needed)
if not previous_text:
self._reset_streaming_state()
# Update accumulated text
self.accumulated_text = current_text
......@@ -371,11 +357,11 @@ class Qwen3CoderToolParser(ToolParser):
self.param_count = 0
self.json_started = False
self.json_closed = False
self.accumulated_params = {}
# Check if there are more tool calls
tool_starts_count = current_text.count(
self.tool_call_start_token)
if self.current_tool_index >= tool_starts_count:
tool_starts = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts:
# No more tool calls
self.is_tool_call_started = False
# Continue processing next tool
......@@ -412,20 +398,20 @@ class Qwen3CoderToolParser(ToolParser):
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
tool_starts: list[int] = []
tool_start_positions: list[int] = []
idx = 0
while True:
idx = current_text.find(self.tool_call_start_token, idx)
if idx == -1:
break
tool_starts.append(idx)
tool_start_positions.append(idx)
idx += len(self.tool_call_start_token)
if self.current_tool_index >= len(tool_starts):
if self.current_tool_index >= len(tool_start_positions):
# No more tool calls to process yet
return None
tool_start_idx = tool_starts[self.current_tool_index]
tool_start_idx = tool_start_positions[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
tool_end_idx = current_text.find(self.tool_call_end_token,
tool_start_idx)
......@@ -438,19 +424,19 @@ class Qwen3CoderToolParser(ToolParser):
# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
func_start = (tool_text.find(self.tool_call_prefix) +
len(self.tool_call_prefix))
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
self.current_function_name = tool_text[func_start:func_end]
self.current_tool_string_id = self._generate_tool_call_id()
self.current_tool_id = self._generate_tool_call_id()
self.header_sent = True
self.in_function = True
# IMPORTANT: Add to prev_tool_call_arr immediately when we
# detect a tool call. This ensures
# IMPORTANT: Add to prev_tool_call_arr immediately when
# we detect a tool call. This ensures
# finish_reason="tool_calls" even if parsing isn't complete
already_added = any(
tool.get("name") == self.current_function_name
......@@ -466,7 +452,7 @@ class Qwen3CoderToolParser(ToolParser):
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_string_id,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""),
type="function",
......@@ -496,10 +482,11 @@ class Qwen3CoderToolParser(ToolParser):
# Close JSON
self.json_closed = True
# Extract the complete tool call to update prev_tool_call_arr
# with final arguments. Find the function content
func_start = (tool_text.find(self.tool_call_prefix) +
len(self.tool_call_prefix))
# Extract complete tool call to update
# prev_tool_call_arr with final arguments
# Find the function content
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_content_end = tool_text.find(self.function_end_token,
func_start)
if func_content_end != -1:
......@@ -507,15 +494,17 @@ class Qwen3CoderToolParser(ToolParser):
# Parse to get the complete arguments
try:
parsed_tool = self._parse_xml_function_call(
func_content, request.tools if request else None)
func_content, self.streaming_request.tools
if self.streaming_request else None)
if parsed_tool:
# Update existing entry in prev_tool_call_arr with
# complete arguments
# Update existing entry in
# prev_tool_call_arr with complete args
for i, tool in enumerate(self.prev_tool_call_arr):
if (tool.get("name") ==
parsed_tool.function.name):
self.prev_tool_call_arr[i]["arguments"] = (
parsed_tool.function.arguments)
if tool.get(
"name") == parsed_tool.function.name:
args = parsed_tool.function.arguments
self.prev_tool_call_arr[i][
"arguments"] = args
break
except Exception:
pass # Ignore parsing errors during streaming
......@@ -530,17 +519,12 @@ class Qwen3CoderToolParser(ToolParser):
# Reset state for next tool
self.in_function = False
self.json_closed = True
self.accumulated_params = {}
return result
# Look for parameters
# Count how many complete parameters we have processed
complete_params = tool_text.count(self.parameter_end_token)
# Check if we should start a new parameter
if not self.in_param and self.param_count < complete_params:
# Find the unprocessed parameter
# Count parameter starts
# Find all parameter starts
param_starts = []
idx = 0
while True:
......@@ -550,7 +534,9 @@ class Qwen3CoderToolParser(ToolParser):
param_starts.append(idx)
idx += len(self.parameter_prefix)
if len(param_starts) > self.param_count:
# Check if we should start a new parameter
if (not self.in_param and self.param_count < len(param_starts)
and len(param_starts) > self.param_count):
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
......@@ -568,23 +554,62 @@ class Qwen3CoderToolParser(ToolParser):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(
self.parameter_end_token)
param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx == -1:
# No closing tag, look for next parameter or
# function end
next_param_idx = value_text.find(self.parameter_prefix)
func_end_idx = value_text.find(self.function_end_token)
if next_param_idx != -1 and (func_end_idx == -1
or next_param_idx
< func_end_idx):
param_end_idx = next_param_idx
elif func_end_idx != -1:
param_end_idx = func_end_idx
else:
# Neither found, check if tool call is complete
if self.tool_call_end_token in tool_text:
# Tool call is complete, so parameter
# must be complete too. Use all
# remaining text before function end
param_end_idx = len(value_text)
else:
# Still streaming, wait for more content
return None
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Build complete JSON fragment for this parameter
# Store raw value for later processing
self.accumulated_params[
self.current_param_name] = param_value
# Get parameter configuration for type conversion
param_config = self._get_arguments_config(
self.current_function_name or "",
self.streaming_request.tools
if self.streaming_request else None)
# Convert param value to appropriate type
converted_value = self._convert_param_value(
param_value, self.current_param_name, param_config,
self.current_function_name or "")
# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value = json.dumps(converted_value,
ensure_ascii=False)
if self.param_count == 0:
json_fragment = (
'"' + self.current_param_name + '": "' +
json.dumps(param_value)[1:-1] + '"')
json_fragment = (f'"{self.current_param_name}": '
f'{serialized_value}')
else:
json_fragment = (
', "' + self.current_param_name + '": "' +
json.dumps(param_value)[1:-1] + '"')
json_fragment = (f', "{self.current_param_name}": '
f'{serialized_value}')
self.param_count += 1
......@@ -596,7 +621,8 @@ class Qwen3CoderToolParser(ToolParser):
)
])
# Continue parameter value
# Continue parameter value - Not used in the current implementation
# since we process complete parameters above
if self.in_param:
if self.parameter_end_token in delta_text:
# End of parameter
......@@ -608,25 +634,42 @@ class Qwen3CoderToolParser(ToolParser):
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if (not self.current_param_value
and value_chunk.startswith("\n")):
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
# Calculate incremental JSON
# Store complete value
full_value = self.current_param_value + value_chunk
prev_escaped = (json.dumps(self.current_param_value)[1:-1]
if self.current_param_value else "")
full_escaped = json.dumps(full_value)[1:-1]
delta_escaped = full_escaped[len(prev_escaped):]
self.accumulated_params[
self.current_param_name] = full_value
# Get parameter configuration for type conversion
param_config = self._get_arguments_config(
self.current_function_name or "",
self.streaming_request.tools
if self.streaming_request else None)
# Convert the parameter value to the appropriate type
converted_value = self._convert_param_value(
full_value, self.current_param_name or "",
param_config, self.current_function_name or "")
# Serialize the converted value
serialized_value = json.dumps(converted_value,
ensure_ascii=False)
# Since we've been streaming the quoted version,
# we need to close it properly
# This is complex - for now just complete the value
self.in_param = False
self.current_param_value = ""
# Just close the current parameter string
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped + '"'),
arguments='"'), # Close the string quote
)
])
else:
......@@ -638,18 +681,18 @@ class Qwen3CoderToolParser(ToolParser):
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if (not self.current_param_value
and value_chunk.startswith("\n")):
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
if value_chunk:
# Stream the escaped delta
prev_escaped = (json.dumps(
self.current_param_value)[1:-1]
if self.current_param_value else "")
prev_escaped = json.dumps(
self.current_param_value, ensure_ascii=False
)[1:-1] if self.current_param_value else ""
self.current_param_value += value_chunk
full_escaped = json.dumps(
self.current_param_value)[1:-1]
full_escaped = json.dumps(self.current_param_value,
ensure_ascii=False)[1:-1]
delta_escaped = full_escaped[len(prev_escaped):]
if delta_escaped:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from qwen3coder xml parser, All rights reserved.
# ruff: noqa: E501
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("seed_oss")
class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: AnyTokenizer):
    """Set up the markers, token ids, regexes and streaming state.

    Args:
        tokenizer: Tokenizer forwarded unchanged to the base-class
            constructor.

    Raises:
        RuntimeError: If ``self.vocab`` has no entry for the
            ``<seed:tool_call>`` or ``</seed:tool_call>`` token.
    """
    super().__init__(tokenizer)

    self.prev_tool_call_arr: list[dict] = []

    self.tool_call_start_token: str = self.TOOL_CALL_START
    self.tool_call_end_token: str = self.TOOL_CALL_END
    # Sentinel tokens for streaming mode
    self.tool_call_prefix: str = "<function="
    self.function_end_token: str = "</function>"
    self.parameter_prefix: str = "<parameter="
    self.parameter_end_token: str = "</parameter>"
    self.think_start_token: str = "<seed:think>"
    self.think_end_token: str = "</seed:think>"
    self.is_tool_call_started: bool = False
    self.is_thinking_end: bool = False
    self.failed_count: int = 0

    # --- streaming state ---
    # Initialised exactly once: nothing above reads the streaming state,
    # so a second, earlier reset call would be redundant.
    self._reset_streaming_state()

    # NOTE(review): self.vocab is not assigned in this method; presumably
    # the base-class constructor provides it - confirm.
    self.tool_call_start_token_id = self.vocab.get(
        self.tool_call_start_token)
    self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
    # May be None: only the two tool-call ids are validated below.
    self.think_end_token_id = self.vocab.get(self.think_end_token)

    if (self.tool_call_start_token_id is None
            or self.tool_call_end_token_id is None):
        raise RuntimeError(
            "Seed_Oss XML parser: tokenizer did not include "
            "<seed:tool_call> or its closing tag.")

    tool_start_re = re.escape(self.tool_call_start_token)
    tool_end_re = re.escape(self.tool_call_end_token)

    # Only fully closed <seed:tool_call>...</seed:tool_call> sections.
    self.tool_call_complete_regex = re.compile(
        rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL)
    # Closed sections (group 1) or a trailing unclosed section (group 2).
    self.tool_call_regex = re.compile(
        rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$",
        re.DOTALL)
    self.tool_call_function_regex = re.compile(
        r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
    self.tool_call_parameter_regex = re.compile(
        r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)

    logger.info("vLLM Seed-Oss XML tool parser loaded (%s).",
                self.__class__.__name__)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = -1
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.accumulated_text = ""
self.json_started = False
self.json_closed = False
def _parse_xml_function_call(
        self, function_call_str: str,
        tools: Optional[list[ChatCompletionToolsParam]]
) -> Optional[ToolCall]:
    """Turn the text of one ``<function=...>`` section into a ``ToolCall``.

    ``function_call_str`` is the text that follows ``<function=``: the
    function name, a ``>``, then any number of
    ``<parameter=NAME>VALUE</parameter>`` sections. Each value is converted
    according to the ``type`` declared for that parameter in ``tools`` and
    the resulting dict is serialised to JSON as the call arguments.

    Args:
        function_call_str: Function name, ``>`` and parameter sections.
        tools: Tool definitions used to look up parameter types; ``None``
            means every value is kept as a string.

    Returns:
        A ``ToolCall`` of type ``"function"`` whose arguments are a JSON
        string.

    Raises:
        ValueError: If the function header or a parameter tag contains no
            ``>`` (raised by ``str.index``).
    """

    def get_arguments_config(func_name: str) -> dict:
        """Return the parameter schema declared for ``func_name``, or {}."""
        if tools is None:
            return {}
        for config in tools:
            if not hasattr(config, "type") or not (
                    hasattr(config, "function")
                    and hasattr(config.function, "name")):
                continue
            if (config.type == "function"
                    and config.function.name == func_name):
                if not hasattr(config.function, "parameters"):
                    return {}
                params = config.function.parameters
                if isinstance(params, dict) and "properties" in params:
                    return params["properties"]
                elif isinstance(params, dict):
                    return params
                else:
                    return {}
        logger.warning("Tool '%s' is not defined in the tools list.",
                       func_name)
        return {}

    def convert_param_value(param_value: str, param_name: str,
                            param_config: dict, func_name: str) -> Any:
        """Convert a raw string value to the type declared in the schema.

        Any conversion failure degrades to the original string (or to
        ``False`` for booleans) instead of raising.
        """
        # Handle null value for any type
        if param_value.lower() == "null":
            return None

        if param_name not in param_config:
            if param_config != {}:
                logger.warning(
                    "Parsed parameter '%s' is not defined in "
                    "the tool parameters for tool '%s', "
                    "directly returning the string value.", param_name,
                    func_name)
            return param_value

        schema = param_config[param_name]
        if isinstance(schema, dict) and "type" in schema:
            param_type = str(schema["type"]).strip().lower()
        else:
            param_type = "string"

        if param_type in ("string", "str", "text", "varchar", "char",
                          "enum"):
            return param_value

        if param_type.startswith(
            ("int", "uint", "long", "short", "unsigned")):
            try:
                return int(param_value)
            except (ValueError, TypeError):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not an integer in tool "
                    "'%s', degenerating to string.", param_value,
                    param_name, func_name)
                return param_value

        if param_type.startswith(("num", "float")):
            try:
                float_value = float(param_value)
                # int() raises ValueError for NaN and OverflowError for
                # +/-infinity; both must degrade to the string form
                # rather than abort the whole tool-call extraction.
                int_value = int(float_value)
            except (ValueError, TypeError, OverflowError):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not a float in tool "
                    "'%s', degenerating to string.", param_value,
                    param_name, func_name)
                return param_value
            # Whole numbers are emitted as int so "3.0" serialises as 3.
            return int_value if float_value == int_value else float_value

        if param_type in ("boolean", "bool", "binary"):
            lowered = param_value.lower()
            if lowered not in ("true", "false"):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not a boolean "
                    "(`true` or `false`) in tool '%s', degenerating to false.",
                    lowered, param_name, func_name)
            return lowered == "true"

        # Structured types: try JSON first so that true/false/null inside
        # objects *and* arrays are understood, then fall back to Python
        # literal syntax.
        if (param_type in ("object", "array", "arr")
                or param_type.startswith(("dict", "list"))):
            try:
                return json.loads(param_value)
            except (ValueError, TypeError, json.JSONDecodeError):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' cannot be parsed "
                    "with json.loads in tool '%s', will try other methods "
                    "to parse it.", param_value, param_name, func_name)
        try:
            # literal_eval raises TypeError for e.g. unhashable dict keys.
            return ast.literal_eval(param_value)
        except (ValueError, SyntaxError, TypeError):
            logger.warning(
                "Parsed value '%s' of parameter '%s' cannot be converted via "
                "Python `ast.literal_eval()` in tool '%s', degenerating to string.",
                param_value, param_name, func_name)
            return param_value

    # Extract function name
    end_index = function_call_str.index(">")
    function_name = function_call_str[:end_index]
    param_config = get_arguments_config(function_name)
    parameters = function_call_str[end_index + 1:]

    param_dict = {}
    for match in self.tool_call_parameter_regex.findall(parameters):
        # Group 0 is a closed <parameter=...></parameter>, group 1 an
        # unclosed trailing one.
        match_text = match[0] if match[0] else match[1]
        idx = match_text.index(">")
        param_name = match_text[:idx]
        param_value = str(match_text[idx + 1:])
        # Remove one leading and one trailing newline around the value
        if param_value.startswith("\n"):
            param_value = param_value[1:]
        if param_value.endswith("\n"):
            param_value = param_value[:-1]

        param_dict[param_name] = convert_param_value(
            param_value, param_name, param_config, function_name)

    return ToolCall(
        type="function",
        function=FunctionCall(name=function_name,
                              arguments=json.dumps(param_dict,
                                                   ensure_ascii=False)),
    )
def _get_function_calls(self, model_output: str) -> list[str]:
# Find all tool calls
matched_ranges = self.tool_call_regex.findall(model_output)
raw_tool_calls = [
match[0] if match[0] else match[1] for match in matched_ranges
]
# Back-off strategy if no tool_call tags found
if len(raw_tool_calls) == 0:
raw_tool_calls = [model_output]
raw_function_calls = []
for tool_call in raw_tool_calls:
raw_function_calls.extend(
self.tool_call_function_regex.findall(tool_call))
function_calls = [
match[0] if match[0] else match[1] for match in raw_function_calls
]
return function_calls
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract all tool calls from a complete (non-streamed) model output.

    When both think markers are present, everything up to and including
    the first ``</seed:think>`` is kept as content and only the text after
    it is searched for tool calls. Any exception during parsing is logged
    and the output is returned unchanged with no tool calls.
    """

    def no_tool_calls() -> ExtractedToolCallInformation:
        # Shared "nothing found" result: the raw output is the content.
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    # Quick check to avoid unnecessary processing
    if self.tool_call_prefix not in model_output:
        return no_tool_calls()

    reasoning, answer = "", model_output
    if (self.think_start_token in model_output
            and self.think_end_token in model_output):
        # Split at the first think end token; the token stays with the
        # reasoning part.
        before, closing_tag, answer = model_output.partition(
            self.think_end_token)
        reasoning = before + closing_tag

    try:
        raw_calls = self._get_function_calls(answer)
        if not raw_calls:
            return no_tool_calls()

        parsed_calls = [
            self._parse_xml_function_call(raw_call, request.tools)
            for raw_call in raw_calls
        ]

        # Populate prev_tool_call_arr for serving layer to set finish_reason
        self.prev_tool_call_arr.clear()
        self.prev_tool_call_arr.extend({
            "name": call.function.name,
            "arguments": call.function.arguments,
        } for call in parsed_calls if call)

        # Content is whatever precedes the first tool call marker, or the
        # first function prefix when no marker is present.
        cut = answer.find(self.tool_call_start_token)
        if cut < 0:
            cut = answer.find(self.tool_call_prefix)
        leading_text = reasoning + answer[:cut]

        return ExtractedToolCallInformation(
            tools_called=bool(parsed_calls),
            tool_calls=parsed_calls,
            content=leading_text or None,
        )
    except Exception:
        logger.exception("Error in extracting tool call from response.")
        return no_tool_calls()
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
"""Turn one streamed chunk into a content delta or a tool-call delta.

The method is a state machine driven by flags stored on ``self``
(``is_thinking_end``, ``is_tool_call_started``, ``header_sent``,
``in_function``, ``json_started``, ``json_closed``, ``param_count``,
``current_tool_index``). For a tool call it emits, over successive
invocations: a header delta carrying the function name and a freshly
generated id, an opening ``{``, one JSON ``"name": "value"`` fragment
per completed parameter, and a closing ``}``.

Args:
    previous_text: Text generated before this chunk. Only its
        emptiness is checked: an empty value resets streaming state.
    current_text: All text generated so far, including this chunk.
    delta_text: The text of this chunk.
    previous_token_ids: Not read by this method.
    current_token_ids: Not read by this method.
    delta_token_ids: Token ids of this chunk; checked for the
        think-end, tool-call-start and tool-call-end token ids.
    request: The originating request; ``request.tools`` is forwarded to
        ``_parse_xml_function_call`` when a function body closes.

Returns:
    A ``DeltaMessage`` with either ``content`` or ``tool_calls`` set, or
    ``None`` when this chunk produces nothing to send.
"""
# If no delta text, return None unless
# it's an EOS token after tool calls
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
# We check for tool calls in the text even if is_tool_call_started
# is False because it might have been reset after processing all tools
if (delta_token_ids
and self.tool_call_end_token_id not in delta_token_ids):
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text))
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token) - current_text.count(
self.tool_call_end_token)
if open_calls == 0:
# Return empty delta message to allow finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Check if this is the first call (reset state if needed)
if not previous_text:
self._reset_streaming_state()
# Update accumulated text
self.accumulated_text = current_text
# Check if we need to advance to next tool
if self.json_closed and not self.in_function:
# Check if this tool call has ended
tool_ends = current_text.count(self.tool_call_end_token)
if tool_ends > self.current_tool_index:
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
self.param_count = 0
self.json_started = False
self.json_closed = False
# Check if there are more tool calls
if self.current_tool_index >= current_text.count(
self.tool_call_start_token):
# No more tool calls
self.is_tool_call_started = False
# Continue processing next tool
return None
# Check if end thinking
if (not self.is_thinking_end
and (self.think_end_token_id in delta_token_ids
or self.think_end_token in delta_text)):
self.is_thinking_end = True
# If thinking hasn't ended yet, don't process any tool calls
if not self.is_thinking_end:
return DeltaMessage(content=delta_text)
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
if (self.tool_call_start_token_id in delta_token_ids
or self.tool_call_start_token in delta_text):
self.is_tool_call_started = True
# Return any content before the tool call
if self.tool_call_start_token in delta_text:
content_before = delta_text[:delta_text.index(
self.tool_call_start_token)]
if content_before:
return DeltaMessage(content=content_before)
return None
else:
# Check if we're between tool calls - skip whitespace
if (current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""):
# We just ended a tool call, skip whitespace
return None
# Normal content, no tool call
return DeltaMessage(content=delta_text)
# Check if we're between tool calls (waiting for next one)
# Count tool calls we've seen vs processed
# NOTE(review): this count covers the whole of current_text, whereas the
# tool_starts list built below only scans after the think-end token.
tool_starts_count = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts_count:
# We're past all tool calls, shouldn't be here
return None
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
# Only process tool calls after think_end_token
# The conditional expression binds loosest here: the offset is
# (find(...) + len(...)) when the think-end token is present, else 0.
think_end_index = current_text.find(self.think_end_token) + len(
self.think_end_token
) if self.think_end_token in current_text else 0
tool_starts: list[int] = []
idx = think_end_index
while True:
idx = current_text.find(self.tool_call_start_token, idx)
if idx == -1:
break
tool_starts.append(idx)
idx += len(self.tool_call_start_token)
if self.current_tool_index >= len(tool_starts):
# No more tool calls to process yet
return None
tool_start_idx = tool_starts[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
tool_end_idx = current_text.find(self.tool_call_end_token,
tool_start_idx)
if tool_end_idx == -1:
tool_text = current_text[tool_start_idx:]
else:
tool_text = current_text[tool_start_idx:tool_end_idx +
len(self.tool_call_end_token)]
# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
self.current_function_name = tool_text[func_start:func_end]
self.current_tool_id = self._generate_tool_call_id(
) # type: ignore
self.header_sent = True
self.in_function = True
# IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
# This ensures finish_reason="tool_calls" even if parsing isn't complete
# NOTE(review): the duplicate check is by function name only, so a
# second call to the same function does not get its own entry here
# - confirm that is intended.
already_added = any(
tool.get("name") == self.current_function_name
for tool in self.prev_tool_call_arr)
if not already_added:
self.prev_tool_call_arr.append({
"name": self.current_function_name,
"arguments":
"{}", # Placeholder, will be updated later
})
# Send header with function info
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""),
type="function",
)
])
return None
# We've sent header, now handle function body
if self.in_function:
# Send opening brace if not sent yet
if (not self.json_started
and self.parameter_prefix not in delta_text):
self.json_started = True
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
])
# Make sure json_started is set if we're processing parameters
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.function_end_token in tool_text:
# Close JSON
self.json_closed = True
# Extract the complete tool call to update prev_tool_call_arr with final arguments
# Find the function content
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_content_end = tool_text.find(self.function_end_token,
func_start)
if func_content_end != -1:
func_content = tool_text[func_start:func_content_end]
# Parse to get the complete arguments
try:
parsed_tool = self._parse_xml_function_call(
func_content, request.tools if request else None)
if parsed_tool:
# Update existing entry in prev_tool_call_arr with complete arguments
# Only the first entry whose name matches is updated (note the break).
for i, tool in enumerate(self.prev_tool_call_arr):
if tool.get(
"name") == parsed_tool.function.name:
self.prev_tool_call_arr[i]["arguments"] = (
parsed_tool.function.arguments)
break
except Exception:
logger.warning(
"Failed to parse tool arguments during streaming.",
exc_info=True)
result = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
])
# Reset state for next tool
self.in_function = False
self.json_closed = True
return result
# Look for parameters
# Count how many complete parameters we have processed
complete_params = tool_text.count(self.parameter_end_token)
# Check if we should start a new parameter
if not self.in_param and self.param_count < complete_params:
# Find the unprocessed parameter
# Count parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
if len(param_starts) > self.param_count:
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
self.current_param_name = remaining[:name_end]
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(
self.parameter_end_token)
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Build complete JSON fragment for this parameter
# json.dumps(...)[1:-1] yields the escaped text without its quotes;
# the value is then always re-wrapped in double quotes.
# NOTE(review): as a result every parameter streams as a JSON string,
# whatever its declared type - confirm this is what clients expect.
if self.param_count == 0:
json_fragment = (
'"' + self.current_param_name + '": "' +
json.dumps(param_value)[1:-1] + '"')
else:
json_fragment = (
', "' + self.current_param_name + '": "' +
json.dumps(param_value)[1:-1] + '"')
self.param_count += 1
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=json_fragment),
)
])
# Continue parameter value
# NOTE(review): nothing in this method sets self.in_param to True, so
# this branch only runs if other code sets it - verify whether it is
# still reachable.
if self.in_param:
if self.parameter_end_token in delta_text:
# End of parameter
end_idx = delta_text.find(self.parameter_end_token)
value_chunk = delta_text[:end_idx]
# Skip past > if at start
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
# Calculate incremental JSON
# Escape the old and the new accumulated value and send only the
# suffix that the new text contributed.
full_value = self.current_param_value + value_chunk
prev_escaped = (json.dumps(self.current_param_value)[1:-1]
if self.current_param_value else "")
full_escaped = json.dumps(full_value)[1:-1]
delta_escaped = full_escaped[len(prev_escaped):]
self.in_param = False
self.current_param_value = ""
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped + '"'),
)
])
else:
# Continue accumulating value
value_chunk = delta_text
# Handle first chunk after param name
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
if value_chunk:
# Stream the escaped delta
prev_escaped = (json.dumps(
self.current_param_value)[1:-1]
if self.current_param_value else "")
self.current_param_value += value_chunk
full_escaped = json.dumps(
self.current_param_value)[1:-1]
delta_escaped = full_escaped[len(prev_escaped):]
if delta_escaped:
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped),
)
])
return None
......@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
......@@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser):
"""
Extract tool calls for streaming mode.
"""
# Simplify detection: if it begins with "[" treat it as a function call
is_function_call = (current_text.strip().startswith("["))
# If not a function call, return normal content
if not is_function_call:
# First, check for a definitive start of a tool call block.
# This prevents premature parsing of incomplete output.
stripped_text = current_text.strip()
preprocessed_content, preprocessed_tool_calls = (
self.preprocess_model_output(current_text))
# For JSON code blocks, we need to detect them earlier, even if incomplete
has_potential_json_block = ("```json" in current_text
or "```\n[" in current_text
or "[TOOL_CALLS]" in current_text
or "<tool_call>" in current_text)
is_tool_call_block = (
stripped_text.startswith("[")
or stripped_text.startswith("<tool_call>")
or stripped_text.startswith("[TOOL_CALLS]") or
# Check if we have thinking tags with JSON-like content following
("</think>[" in current_text) or
# Check if the text contains a JSON array after preprocessing
preprocessed_tool_calls is not None or
# For JSON code blocks, detect early if we see enough structure
(has_potential_json_block and '"name"' in current_text
and '"arguments"' in current_text))
if not is_tool_call_block:
return DeltaMessage(content=delta_text)
try:
......@@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser):
# Try parsing as JSON to check for complete tool calls
try:
parsed_tools = json.loads(current_text)
# Use preprocessed tool calls if available
tool_calls_text = (preprocessed_tool_calls if
preprocessed_tool_calls else current_text)
parsed_tools = json.loads(tool_calls_text)
if isinstance(parsed_tools, list):
# Update our tool array for next time
self.prev_tool_call_arr = parsed_tools
......@@ -226,7 +249,7 @@ class xLAMToolParser(ToolParser):
function_name = name_match.group(1)
# The test expects us to send just the name first
tool_id = random_tool_call_id()
tool_id = make_tool_call_id()
delta = DeltaMessage(tool_calls=[
DeltaToolCall(
index=0,
......@@ -257,12 +280,39 @@ class xLAMToolParser(ToolParser):
return delta
# Use regex to identify tool calls in the output
# Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
search_text = (preprocessed_tool_calls
if preprocessed_tool_calls else current_text)
# For JSON code blocks that aren't complete yet, try to extract the JSON content
if not preprocessed_tool_calls and has_potential_json_block:
# Try to extract the JSON array from within the code block
json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)",
current_text)
if json_match:
potential_json = json_match.group(1).strip()
# Use this as search text even if it's incomplete
if potential_json.startswith("[") and (
'"name"' in potential_json
and '"arguments"' in potential_json):
search_text = potential_json
# Try to find complete tool names first
name_pattern = r'"name"\s*:\s*"([^"]+)"'
name_matches = list(re.finditer(name_pattern, current_text))
name_matches = list(re.finditer(name_pattern, search_text))
tool_count = len(name_matches)
# If no tools found yet, return
# If no complete tool names found, check for partial tool names
if tool_count == 0:
# Check if we're in the middle of parsing a tool name
partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
partial_matches = list(
re.finditer(partial_name_pattern, search_text))
if partial_matches:
# We have a partial tool name - not ready to emit yet
return None
else:
# No tools found at all
return None
# Ensure our state arrays are large enough
......@@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser):
# First, check for the empty arguments case: "arguments": {}
empty_args_pattern = (
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
empty_args_match = re.search(empty_args_pattern, current_text)
empty_args_match = re.search(empty_args_pattern, search_text)
# Check if this tool has empty arguments
if empty_args_match and empty_args_match.start() > 0:
......@@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser):
# Extract arguments for current tool using regex for non-empty arguments
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
args_matches = list(re.finditer(args_pattern, current_text))
args_matches = list(re.finditer(args_pattern, search_text))
if current_idx < len(args_matches):
args_text = args_matches[current_idx].group(1)
......@@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser):
# Handle transition between tools
is_last_tool = current_idx == tool_count - 1
# Find where the arguments for our current tool end
if not is_last_tool:
# If we have more tools after this one, try to find the complete argument block
next_tool_pos = current_text.find(
"},{", args_matches[current_idx].start())
if next_tool_pos != -1:
args_end_pos = (next_tool_pos + 1
) # +1 to include the '}'
args_text = (current_text[args_matches[current_idx]
.start():args_end_pos].
split('"arguments":')[1].strip())
# For multiple tools, extract only the arguments for the current tool
if tool_count > 1:
# Parse the entire JSON structure to properly extract arguments for each tool
try:
parsed_tools = json.loads(search_text)
if isinstance(
parsed_tools,
list) and current_idx < len(parsed_tools):
current_tool = parsed_tools[current_idx]
if isinstance(current_tool.get("arguments"),
dict):
args_text = json.dumps(
current_tool["arguments"])
else:
args_text = str(
current_tool.get("arguments", "{}"))
except (json.JSONDecodeError, KeyError, IndexError):
# Fallback to regex-based extraction
pass
# If arguments haven't been sent yet
sent_args = self.streaming_state["sent_tools"][
......
......@@ -313,12 +313,14 @@ def log_non_default_args(args: Union[argparse.Namespace, EngineArgs]):
# Handle EngineArgs instance
elif isinstance(args, EngineArgs):
default_args = EngineArgs() # Create default instance
default_args = EngineArgs(model=args.model) # Create default instance
for field in dataclasses.fields(args):
current_val = getattr(args, field.name)
default_val = getattr(default_args, field.name)
if current_val != default_val:
non_default_args[field.name] = current_val
if default_args.model != EngineArgs.model:
non_default_args["model"] = default_args.model
else:
raise TypeError("Unsupported argument type. " \
"Must be argparse.Namespace or EngineArgs instance.")
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import json
import os
import sys
import tempfile
......@@ -42,7 +43,6 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
......@@ -99,6 +99,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
......@@ -131,7 +132,9 @@ if TYPE_CHECKING:
VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
......@@ -159,9 +162,12 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
def get_default_cache_root():
......@@ -465,11 +471,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,
# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
"VLLM_FLASHINFER_FORCE_TENSOR_CORES":
lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
......@@ -667,11 +668,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR":
lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None),
# Enables torch profiler if set. Path to the directory where torch profiler
# traces are saved. Note that it must be an absolute path.
# Enables torch profiler if set.
# Both AsyncLLM's CPU traces as well as workers'
# traces (CPU & GPU) will be saved under this directory.
# Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
.path.abspath(os.path.expanduser(os.getenv(
"VLLM_TORCH_PROFILER_DIR", ".")))),
# Enable torch profiler to record shapes if set
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
......@@ -771,6 +775,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP8BMM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in
("true", "1")),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
......@@ -953,9 +963,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
# E8M0 is faster on B200 but may reduce accuracy.
"VLLM_USE_DEEP_GEMM_E8M0":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
# TODO(wentao): unify the two E8M0 flags after verifying the correctness.
# Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs.
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
......@@ -964,6 +977,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_SKIP_DEEP_GEMM_WARMUP":
lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))),
# Whether to use fused grouped_topk used for MoE expert selection.
"VLLM_USE_FUSED_MOE_GROUPED_TOPK":
lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))),
# Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
......@@ -1042,6 +1059,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
# Specifies the thresholds of the communicated tensor sizes under which
# vllm should use flashinfer fused allreduce. The variable should be a
# JSON with the following format:
# { <world size>: <max size in mb> }
# Unspecified world sizes will fall back to
# { 2: 64, 4: 1, <everything else>: 0.5 }
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB":
lambda: json.loads(os.getenv(
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")),
# MoE routing strategy selector.
# See `RoutingSimulator.get_available_strategies()` for available
# strategies.
......@@ -1108,6 +1135,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN":
lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
# vllm cutlass GEMM, marlin GEMM.
......@@ -1120,6 +1152,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_CUDAGRAPH_GC":
lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),
# Disable padding to CUDA graph capture batch sizes.
# TODO(wentao): https://github.com/vllm-project/vllm/issues/23378
# After the issue is fixed, we can remove this flag.
"VLLM_DISABLE_PAD_FOR_CUDAGRAPH":
lambda: bool(int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0"))),
# Used to force set up loopback IP
"VLLM_LOOPBACK_IP":
lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
......@@ -1153,6 +1191,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RESPONSES_API_STORE":
lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
# Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
......@@ -1218,10 +1260,12 @@ def compute_hash() -> str:
"VLLM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ATTENTION_BACKEND",
"VLLM_USE_FLASHINFER_SAMPLER",
"VLLM_FLASHINFER_FORCE_TENSOR_CORES",
"VLLM_DISABLED_KERNELS",
"VLLM_USE_DEEP_GEMM",
"VLLM_USE_DEEP_GEMM_E8M0",
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
"VLLM_USE_TRTLLM_FP4_GEMM",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
......@@ -1235,6 +1279,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_RMSNORM",
"VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",
"VLLM_ROCM_MOE_PADDING",
......
......@@ -101,7 +101,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
result_handler.start()
self.worker_monitor.start()
# Set up signal handlers to shutdown the executor cleanly
# Set up signal handlers to shut down the executor cleanly
# sometimes gc does not work well
self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)
......
......@@ -4,11 +4,12 @@
from array import array
from typing import Any, Type
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types.
"""Custom msgspec enc hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
......@@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any:
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}.")
return obj.tobytes()
if isinstance(obj, MultiModalKwargs):
return dict(obj)
def decode_hook(type: Type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types.
"""Custom msgspec dec hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
......@@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj)
return deserialized
if type is MultiModalKwargs:
return MultiModalKwargs(obj)
......@@ -10,6 +10,7 @@ import msgspec
import vllm.platforms
from vllm.config import ParallelConfig
from vllm.distributed import get_pp_group
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -136,6 +137,11 @@ try:
scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
output = scheduler_output, None
return output
def override_env_vars(self, vars: Dict[str, str]):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
......@@ -18,6 +18,7 @@ target model.
"""
__all__ = [
"DataPrompt",
"TextPrompt",
"TokensPrompt",
"PromptType",
......
......@@ -7,7 +7,8 @@ import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
MultiModalUUIDDict)
class TextPrompt(TypedDict):
......@@ -30,6 +31,15 @@ class TextPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching, and MUST be unique per
multimodal item.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
......@@ -59,6 +69,14 @@ class TokensPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
......@@ -77,6 +95,16 @@ class EmbedsPrompt(TypedDict):
"""
class DataPrompt(TypedDict):
    """Generic prompt payload that is handled by IO processor plugins.

    Both keys are required: the payload itself and a string naming the
    format it is expressed in.
    """

    data: Any
    """The input data"""

    data_format: str
    """The input data format"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:
......@@ -174,9 +202,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: list[int]
"""The token IDs of the prompt."""
token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""
prompt: NotRequired[str]
"""
The original prompt text corresponding to the token IDs, if available.
......@@ -190,7 +215,6 @@ class TokenInputs(TypedDict):
def token_inputs(
prompt_token_ids: list[int],
token_type_ids: Optional[list[int]] = None,
prompt: Optional[str] = None,
cache_salt: Optional[str] = None,
) -> TokenInputs:
......@@ -200,8 +224,6 @@ def token_inputs(
if prompt is not None:
inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
......
......@@ -11,8 +11,9 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
......@@ -32,12 +33,14 @@ class InputPreprocessor:
model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None:
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
......@@ -254,7 +257,9 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs:
"""
Apply the model's multi-modal processor to a multi-modal prompt,
......@@ -262,17 +267,22 @@ class InputPreprocessor:
"""
tokenizer = self._get_mm_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return mm_processor.apply(prompt,
return mm_processor.apply(
prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes)
mm_hash_overrides=mm_hash_overrides,
)
async def _process_multimodal_async(
self,
......@@ -281,7 +291,9 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs:
"""
Async version of
......@@ -289,16 +301,22 @@ class InputPreprocessor:
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return mm_processor.apply(prompt,
return mm_processor.apply(
prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes)
mm_hash_overrides=mm_hash_overrides,
)
def _process_embeds(
self,
......@@ -330,15 +348,33 @@ class InputPreprocessor:
) -> EmbedsInputs:
return self._process_embeds(parsed_content)
def _truncate_inputs(
        self,
        inputs: list[int],
        tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:
    """Truncate pre-tokenized prompt token ids to ``max_length``.

    Truncation is applied only when ``tokenization_kwargs`` contains a
    ``"truncation"`` key and a tokenizer is available. Which end is kept
    is decided by ``self.tokenizer.truncation_side``.

    Args:
        inputs: Token ids of the prompt.
        tokenization_kwargs: Optional tokenization options; the keys read
            here are ``"truncation"`` (presence only) and ``"max_length"``.

    Returns:
        ``inputs`` unchanged when no truncation is requested or no usable
        ``max_length`` is given; otherwise a list holding at most
        ``max_length`` token ids.
    """
    if (not tokenization_kwargs or "truncation" not in tokenization_kwargs
            or self.tokenizer is None):
        return inputs

    # A truncation request without a concrete limit cannot be applied here;
    # leave the ids untouched instead of raising KeyError / TypeError.
    max_length = tokenization_kwargs.get("max_length")
    if max_length is None:
        return inputs

    if self.tokenizer.truncation_side == "left":
        # ``inputs[-0:]`` is the whole list, so a limit of 0 needs an
        # explicit empty result to agree with the right-side branch.
        return inputs[-max_length:] if max_length else []
    return inputs[:max_length]
def _process_tokens(
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
......@@ -348,13 +384,10 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
inputs = token_inputs(prompt_token_ids=prompt_token_ids)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
......@@ -366,10 +399,12 @@ class InputPreprocessor:
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
......@@ -379,13 +414,10 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
......@@ -397,7 +429,9 @@ class InputPreprocessor:
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
......@@ -409,7 +443,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
else:
prompt_token_ids = self._tokenize_prompt(
......@@ -432,7 +466,9 @@ class InputPreprocessor:
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
......@@ -444,7 +480,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
......@@ -467,7 +503,9 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> SingletonInputs:
"""
Extract the singleton inputs from a prompt.
......@@ -476,7 +514,6 @@ class InputPreprocessor:
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
* return_mm_hashes: whether to return multimodal hashes
Returns:
......@@ -490,21 +527,21 @@ class InputPreprocessor:
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
assert_never(parsed)
......@@ -514,7 +551,9 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> SingletonInputs:
"""
Async version of
......@@ -528,21 +567,21 @@ class InputPreprocessor:
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
assert_never(parsed)
......@@ -652,6 +691,9 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
......@@ -693,6 +735,7 @@ class InputPreprocessor:
encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
......@@ -708,6 +751,7 @@ class InputPreprocessor:
inputs = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
......@@ -723,6 +767,9 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs:
"""
Async version of
......@@ -735,6 +782,7 @@ class InputPreprocessor:
encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
......@@ -744,6 +792,7 @@ class InputPreprocessor:
decoder_task = self._prompt_to_llm_inputs_async(
decoder_input,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
encoder_inputs, decoder_inputs = await asyncio.gather(
......@@ -759,6 +808,7 @@ class InputPreprocessor:
inputs = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
......@@ -785,7 +835,9 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs:
"""
For decoder-only models:
......@@ -796,7 +848,6 @@ class InputPreprocessor:
* prompt: input prompt
* lora_request
* return_mm_hashes
Returns:
......@@ -807,7 +858,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
......@@ -817,7 +868,9 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs:
"""
Async version of
......@@ -827,7 +880,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
......@@ -837,17 +890,19 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
# input prompts to encoder & decoder.
return self._process_encoder_decoder_prompt(
prompt, tokenization_kwargs)
prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
......@@ -858,7 +913,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
async def preprocess_async(
......@@ -866,19 +921,22 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> ProcessorInputs:
"""
Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(prompt)
# input prompts to encoder & decoder.
return await self._process_encoder_decoder_prompt_async(
prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
......@@ -889,5 +947,9 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
def clear_cache(self) -> None:
    """Clear the multi-modal processor cache, if one was configured."""
    cache = self.mm_processor_cache
    if cache is None:
        # No cache was supplied to this preprocessor; nothing to clear.
        return
    cache.clear_cache()
......@@ -223,20 +223,26 @@ class InputRegistry:
The model is identified by ``model_config``.
"""
# Avoid circular import
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.sequence import SequenceData
if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data)
cache = processor_only_cache_from_config(model_config, mm_registry)
# Encoder dummy data does not contain multi-modal data
if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data(
model_config, seq_len)
enc_data = mm_registry.get_encoder_dummy_data(model_config,
seq_len,
cache=cache)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data)
dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len)
dec_data = mm_registry.get_decoder_dummy_data(model_config,
seq_len,
cache=cache)
return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
......
......@@ -48,9 +48,6 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# GPTQ/AWQ
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
# marlin
elif hasattr(base_layer, "B"):
return base_layer.B.device
# HQQ marlin
elif hasattr(base_layer, "W_q"):
return base_layer.W_q.device
......@@ -608,7 +605,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
packed together (eg. gate_proj + up_proj -> gate_up_proj).
packed together (e.g. gate_proj + up_proj -> gate_up_proj).
This means we have 2 LoRAs, each applied to one half of the layer.
......
......@@ -207,6 +207,7 @@ class LoRAModel(AdapterModel):
"""
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
......@@ -255,9 +256,10 @@ class LoRAModel(AdapterModel):
check_unexpected_modules(f)
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path):
# When a bin file is provided, we rely on config to find unexpected
# modules.
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(
lora_pt_file_path):
# When a bin/pt file is provided, we rely on config to find
# unexpected modules.
unexpected_modules = []
target_modules = peft_helper.target_modules
if not isinstance(target_modules, list):
......@@ -279,7 +281,10 @@ class LoRAModel(AdapterModel):
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path,
lora_file_path = (lora_bin_file_path
if os.path.isfile(lora_bin_file_path) else
lora_pt_file_path)
tensors = torch.load(lora_file_path,
map_location=device,
weights_only=True)
else:
......
......@@ -10,11 +10,14 @@ import torch.nn.functional as F
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict
logger = init_logger(__name__)
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
......@@ -363,6 +366,112 @@ class ReLUSquaredActivation(CustomOp):
return self.forward_native(x)
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    xIELU activation function, as introduced in
    https://arxiv.org/abs/2411.13010.

    If the optional nickjbrowning/XIELU extension can be imported, its CUDA
    implementation is used for CUDA inputs outside of torch._dynamo
    compilation. Otherwise a warning is logged once and the pure PyTorch
    implementation is used.
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()

        # Both alphas are stored as log(exp(v) - 1), i.e. the inverse of the
        # softplus that `_xielu_python` applies when reading them back.
        raw_alpha_p = torch.log(
            torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1)
        self.alpha_p = nn.Parameter(raw_alpha_p.unsqueeze(0))
        raw_alpha_n = torch.log(
            torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1)
        self.alpha_n = nn.Parameter(raw_alpha_n.unsqueeze(0))

        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads

        # Temporary until xIELU CUDA fully implemented: keep plain Python
        # floats around so the CUDA path does not need to call .item().
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        # Stays None unless the optional extension loads successfully.
        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            status = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph
                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
            except Exception as err:
                status += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance.")
                self._xielu_cuda_fn = self._xielu_cuda
            else:
                status += " Enabled torch._dynamo for xIELU CUDA."
            logger.warning_once(status)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Pure PyTorch xIELU, used whenever the CUDA path is not taken."""
        softplus = nn.functional.softplus
        alpha_p = softplus(self.alpha_p)
        alpha_n = self.beta + softplus(self.alpha_n)

        # Branch taken where x > 0.
        positive_branch = alpha_p * x * x + self.beta * x
        # Branch taken elsewhere; the expm1 argument is capped at eps.
        negative_branch = ((torch.expm1(torch.min(x, self.eps)) - x) * alpha_n
                           + self.beta * x)
        return torch.where(x > 0, positive_branch, negative_branch)

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the CUDA kernel behind a wrapper that uses the cached scalars.

        Passing the precomputed Python floats keeps `.item()` calls out of
        this function, so torch.compile does not encounter them here.
        """
        assert self._xielu_cuda_obj is not None, (
            "XIELU CUDA object must not be None")

        input_shape = x.shape
        # CUDA kernel expects 3D tensors: pad leading dims when there are
        # too few, collapse the leading dims when there are too many.
        for _ in range(3 - x.dim()):
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if input_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                input_shape,
                x.shape,
            )

        output = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        # Hand the result back in the caller's original shape.
        return output.view(input_shape)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply xIELU, preferring the CUDA kernel when it can be used."""
        cuda_usable = self._xielu_cuda_obj is not None and input.is_cuda
        if not cuda_usable:
            return self._xielu_python(input)

        if torch._dynamo.is_compiling():
            # The CUDA path is skipped while torch._dynamo is compiling.
            logger.warning_once(
                "torch._dynamo is compiling, using Python version of xIELU."
            )
            return self._xielu_python(input)

        return self._xielu_cuda_fn(input)
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
......@@ -422,12 +531,25 @@ _ACTIVATION_REGISTRY = LazyDict({
lambda: nn.SiLU(),
"quick_gelu":
lambda: QuickGELU(),
"tanh":
lambda: nn.Tanh(),
"sigmoid":
lambda: nn.Sigmoid(),
"xielu":
lambda: XIELU(),
})
def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name.startswith("torch.nn.modules."):
activation_name = act_fn_name.split(".")[-1]
if activation_name == "identity":
return nn.Identity()
act_fn_name = activation_name
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class AttentionLayerBase(ABC):
    """
    Abstract parent for attention-like layers (Attention, Mamba, etc.)
    that support the v1 engine.

    Subclasses expose their attention backend through one shared method,
    so callers can query any such layer type the same way.
    """

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Return the attention backend class used by this layer."""
        ...
......@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
is_blackwell_deep_gemm_e8m0_used)
is_deep_gemm_e8m0_used)
logger = init_logger(__name__)
......@@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm(
# number of valid tokens for this expert
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
cols = tl.arange(0, BLOCK)
cols = cols.to(tl.int64)
mask_h = cols < BLOCK
cols = tl.arange(0, BLOCK).to(tl.int64)
mask = cols < BLOCK
base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
base_gate_offset = base_input_offset + cols * stride_i_h
base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h +
cols * stride_yq_h)
base_ys_offset = e * stride_ys_e + g * stride_ys_g
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
base_i_offset = (e * stride_i_e + t * stride_i_t +
g * GROUP_SIZE * stride_i_h)
base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
g * GROUP_SIZE * stride_yq_h)
base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g
mask = mask_h
x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t,
mask=mask,
other=0.0).to(tl.float32)
y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
cols * stride_i_h,
up = tl.load(input_ptr + base_up_offset + t * stride_i_t,
mask=mask,
other=0.0).to(tl.float32)
other=0.0)
x = x * (1.0 / (1.0 + tl.exp(-x)))
y = x * y2
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
y = gate * up
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
if use_ue8m0:
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
scale_raw = _absmax / fp8_max
y_s = tl.math.exp2(tl.ceil(
tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
tl.store(y_s_ptr + base_ys_offset, y_s)
tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
def silu_mul_fp8_quant_deep_gemm(
y: torch.Tensor, # (E, T, 2*H) float32
y: torch.Tensor, # (E, T, 2*H)
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
group_size: int = 128,
eps: float = 1e-10,
):
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
Returns `(y_q, y_s)` where
* `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
* `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
"""
assert y.ndim == 3, "y must be (E, T, 2*H)"
E, T, H2 = y.shape
......@@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm(
stride_cnt_e = tokens_per_expert.stride()[0]
# static grid over experts and H-groups.
# Static grid over experts and H-groups.
# A loop inside the kernel handles the token dim
grid = (E * G, )
......@@ -176,9 +174,9 @@ def silu_mul_fp8_quant_deep_gemm(
eps,
fp8_min,
fp8_max,
is_blackwell_deep_gemm_e8m0_used(),
is_deep_gemm_e8m0_used(),
BLOCK=group_size,
NUM_STAGES=8,
NUM_STAGES=4,
num_warps=1,
)
......
......@@ -190,12 +190,6 @@ class FusedMoEParallelConfig:
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@property
def use_flashinfer_cutlass_kernels(self):
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
......@@ -404,7 +398,14 @@ class FusedMoEConfig:
@property
def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
"""
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
"""
return (self.quant_config is not None
and self.quant_config.quant_dtype == "nvfp4"
and envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod
def make(
......@@ -450,6 +451,12 @@ class FusedMoEConfig:
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Config)
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
quant_dtype = "mxfp8"
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
......
{
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
}
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment