Unverified Commit 51c38163 authored by Chang Su, committed by GitHub

model: support Step3V (#8583)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: nnnobody-code <nnnobody@foxmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Qiaolin-Yu <qy254@cornell.edu>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent 09f1a247
......@@ -148,7 +148,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--file-storage-path` | The path of the file storage in backend. | sglang_storage |
| `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False |
| `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None |
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None |
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'. | None |
## Data parallelism
......
......@@ -5,6 +5,11 @@ from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig,
)
__all__ = [
"ExaoneConfig",
......@@ -14,4 +19,7 @@ __all__ = [
"MultiModalityConfig",
"KimiVLConfig",
"MoonViTConfig",
"Step3VLConfig",
"Step3TextConfig",
"Step3VisionEncoderConfig",
]
......@@ -335,6 +335,8 @@ class ModelConfig:
"num_key_value_heads",
# For ChatGLM:
"multi_query_group_num",
# For Step3
"num_attention_groups",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_text_config, attr, None)
......@@ -644,6 +646,7 @@ multimodal_model_archs = [
"InternS1ForConditionalGeneration",
"Phi4MMForCausalLM",
"VILAForConditionalGeneration",
"Step3VLForConditionalGeneration",
]
......
from typing import Any, Optional, Union
from transformers.configuration_utils import PretrainedConfig
class Step3VisionEncoderConfig(PretrainedConfig):
model_type = "step3_vision_encoder"
def __init__(
self,
hidden_size=1792,
intermediate_size=3072,
output_hidden_size=4096,
num_hidden_layers=63,
num_attention_heads=16,
num_channels=3,
image_size=728,
patch_size=14,
hidden_act="quick_gelu",
layer_norm_eps=1e-5,
**kwargs,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.output_hidden_size = output_hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
super().__init__(**kwargs)
class Step3TextConfig(PretrainedConfig):
model_type = "step3_text"
architectures = ["Step3TextForCausalLM"]
def __init__(
self,
hidden_size: int = 7168,
intermediate_size: int = 18432,
num_attention_heads: int = 64,
num_attention_groups: int = 1,
num_hidden_layers: int = 61,
max_seq_len: int = 65536,
vocab_size: int = 128815,
rms_norm_eps: float = 1e-5,
moe_intermediate_size: int = 5120,
moe_num_experts: int = 48,
moe_top_k: int = 3,
rope_theta: float = 500000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embedding: int = 65536,
share_expert_dim: int = 5120,
share_q_dim: int = 2048,
head_dim: int = 256,
norm_expert_weight: bool = False,
moe_layers_enum: tuple[int, ...] = (
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
58, 59,
),
**kwargs,
) -> None:
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_attention_groups = num_attention_groups
self.num_hidden_layers = num_hidden_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.rms_norm_eps = rms_norm_eps
self.moe_intermediate_size = moe_intermediate_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.max_position_embedding = max_position_embedding
self.share_expert_dim = share_expert_dim
self.share_q_dim = share_q_dim
self.head_dim = head_dim
self.norm_expert_weight = norm_expert_weight
self.moe_layers_enum = moe_layers_enum
super().__init__(**kwargs)
class Step3VLConfig(PretrainedConfig):
model_type = "step3_vl"
def __init__(
self,
vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
text_config: Optional[Union[dict, Step3TextConfig]] = None,
understand_projector_stride: int = 1,
projector_bias: bool = True,
image_token_id: int = 128001,
**kwargs,
) -> None:
if vision_config is None:
vision_config = Step3VisionEncoderConfig()
elif isinstance(vision_config, dict):
vision_config = Step3VisionEncoderConfig(**vision_config)
self.vision_config = vision_config
if text_config is None:
text_config = Step3TextConfig()
elif isinstance(text_config, dict):
text_config = Step3TextConfig(**text_config)
self.text_config = text_config
self.understand_projector_stride = understand_projector_stride
self.projector_bias = projector_bias
self.hidden_size = text_config.hidden_size
self.image_token_id = image_token_id
super().__init__(**kwargs)
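A brief usage sketch of the config classes defined above. The field values below are placeholders, not Step3's real shapes; the point is that dict inputs are promoted to `Step3VisionEncoderConfig` / `Step3TextConfig` instances and that `hidden_size` mirrors the text config:

```python
# Illustrative only; the numeric values are placeholders, not Step3's real shapes.
from sglang.srt.configs import Step3TextConfig, Step3VisionEncoderConfig, Step3VLConfig

cfg = Step3VLConfig(
    vision_config={"hidden_size": 1792, "num_hidden_layers": 63},
    text_config={"hidden_size": 7168, "moe_num_experts": 48},
)
assert isinstance(cfg.vision_config, Step3VisionEncoderConfig)
assert isinstance(cfg.text_config, Step3TextConfig)
assert cfg.hidden_size == cfg.text_config.hidden_size  # copied from text_config
```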
......@@ -994,6 +994,23 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="step3-vl",
system_message="<|begin▁of▁sentence|>You are a helpful assistant",
system_template="{system_message}\n",
roles=(
"<|BOT|>user\n",
"<|BOT|>assistant\n<think>\n",
),
sep="<|EOT|>",
sep_style=SeparatorStyle.NO_COLON_SINGLE,
stop_str="<|EOT|>",
image_token="<im_patch>",
# add_bos=True,
)
)
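By my reading of the `Conversation` fields above (NO_COLON_SINGLE separator style), a single-turn prompt should assemble roughly as in the sketch below. This is a hedged reconstruction of the wire format, not output captured from SGLang; the exact concatenation rules of the separator style are assumed.

```python
# Hedged sketch of the prompt the "step3-vl" template above should produce.
system = "<|begin▁of▁sentence|>You are a helpful assistant\n"
user_turn = "<|BOT|>user\n" + "Describe this image. <im_patch>" + "<|EOT|>"
assistant_prefix = "<|BOT|>assistant\n<think>\n"  # generation continues from here
prompt = system + user_turn + assistant_prefix
```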
@register_conv_template_matching_function
def match_internvl(model_path: str):
......@@ -1103,3 +1120,9 @@ def match_vila(model_path: str):
def match_mimo_vl(model_path: str):
if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
return "mimo-vl"
# @register_conv_template_matching_function
# def match_step3(model_path: str):
# if re.search(r"step3", model_path, re.IGNORECASE):
# return "step3-vl"
......@@ -17,6 +17,7 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.function_call.step3_detector import Step3Detector
logger = logging.getLogger(__name__)
......@@ -39,6 +40,7 @@ class FunctionCallParser:
"kimi_k2": KimiK2Detector,
"qwen3_coder": Qwen3CoderDetector,
"glm45": Glm4MoeDetector,
"step3": Step3Detector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
......
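For orientation, the new `"step3"` entry means a server started with `--tool-call-parser step3` routes tool-call text through `Step3Detector` (defined in the next file). Below is a schematic of the lookup that `FunctionCallParser.__init__` performs; the attribute name that holds the registry is not shown in this hunk, so the mapping here is illustrative only.

```python
# Illustrative lookup, mirroring the registry entries shown above.
from typing import Dict, Type

from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.step3_detector import Step3Detector

DETECTORS: Dict[str, Type[BaseFormatDetector]] = {"step3": Step3Detector}

def make_detector(name: str) -> BaseFormatDetector:
    if name not in DETECTORS:
        raise ValueError(f"Unsupported tool-call parser: {name}")
    return DETECTORS[name]()
```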
import ast
import json
import logging
import re
from typing import Any, Dict, List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
logger = logging.getLogger(__name__)
def get_argument_type(
func_name: str, arg_key: str, defined_tools: List[Tool]
) -> Optional[str]:
"""Get the expected type for a function argument from tool schema."""
name2tool = {tool.function.name: tool for tool in defined_tools}
if func_name not in name2tool:
return None
tool = name2tool[func_name]
parameters = tool.function.parameters or {}
properties = parameters.get("properties", {})
if arg_key not in properties:
return None
return properties[arg_key].get("type", None)
def parse_arguments(value: str) -> tuple[Any, bool]:
"""Parse a string value to appropriate type. Returns (parsed_value, success)."""
try:
try:
parsed_value = json.loads(value)
except Exception:
parsed_value = ast.literal_eval(value)
return parsed_value, True
except Exception:
return value, False
class Step3Detector(BaseFormatDetector):
"""
Detector for Step3 model function call format.
The Step3 format uses special Unicode tokens to delimit function calls
with steptml XML format for invocations.
Format Structure:
```
<|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="function_name">
<steptml:parameter name="param1">value1</steptml:parameter>
<steptml:parameter name="param2">value2</steptml:parameter>
</steptml:invoke><|tool_call_end|>
<|tool_calls_end|>
```
"""
def __init__(self):
super().__init__()
self.bot_token = "<|tool_calls_begin|>"
self.eot_token = "<|tool_calls_end|>"
self.tool_call_begin = "<|tool_call_begin|>"
self.tool_call_end = "<|tool_call_end|>"
self.tool_sep = "<|tool_sep|>"
# Regex for parsing steptml invocations
self.invoke_regex = re.compile(
r'<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>', re.DOTALL
)
self.param_regex = re.compile(
r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>', re.DOTALL
)
# Streaming state variables
self._in_tool_block: bool = False
self._tool_block_finished: bool = False
self._current_function_name: str = ""
self._current_parameters: Dict[str, Any] = {}
self._in_tool_call: bool = False
self._function_name_sent: bool = False
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Step3 format tool call."""
return self.bot_token in text
def _parse_steptml_invoke(
self, text: str, tools: Optional[List[Tool]] = None
) -> tuple[Optional[str], dict]:
"""Parse steptml invoke format to extract function name and parameters."""
invoke_match = self.invoke_regex.search(text)
if not invoke_match:
return None, {}
func_name = invoke_match.group(1)
params_text = invoke_match.group(2)
params = {}
for param_match in self.param_regex.finditer(params_text):
param_name = param_match.group(1)
param_value = param_match.group(2).strip()
# If tools provided, use schema-aware parsing
if tools:
arg_type = get_argument_type(func_name, param_name, tools)
if arg_type and arg_type != "string":
parsed_value, _ = parse_arguments(param_value)
params[param_name] = parsed_value
else:
params[param_name] = param_value
else:
# Fallback to generic parsing if no tools provided
parsed_value, _ = parse_arguments(param_value)
params[param_name] = parsed_value
return func_name, params
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
"""
if self.bot_token not in text:
return StreamingParseResult(normal_text=text, calls=[])
try:
pre_text, rest = text.split(self.bot_token, 1)
# If no end token, return everything as normal text
if self.eot_token not in rest:
return StreamingParseResult(normal_text=text, calls=[])
tool_section, post_text = rest.split(self.eot_token, 1)
# Find all individual tool calls using regex
calls = []
tool_call_pattern = (
f"{re.escape(self.tool_call_begin)}(.*?){re.escape(self.tool_call_end)}"
)
for match in re.finditer(tool_call_pattern, tool_section, re.DOTALL):
call_content = match.group(1)
# Check if it's a function call
if self.tool_sep not in call_content:
continue
type_part, invoke_part = call_content.split(self.tool_sep, 1)
if type_part.strip() != "function":
continue
func_name, params = self._parse_steptml_invoke(invoke_part, tools)
if func_name:
# Use parse_base_json to create the ToolCallItem
action = {"name": func_name, "arguments": params}
calls.extend(self.parse_base_json(action, tools))
# Combine pre and post text
normal_text = pre_text + post_text
return StreamingParseResult(normal_text=normal_text, calls=calls)
except Exception as e:
logger.error(f"Error in detect_and_parse: {e}")
# Return the original text if parsing fails
return StreamingParseResult(normal_text=text)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing for Step3 format.
"""
self._buffer += new_text
# Build tool indices for validation
if not hasattr(self, "_tool_indices"):
self._tool_indices = self._get_tool_indices(tools)
# If we've finished the tool block, everything is normal text
if self._tool_block_finished:
normal_text = self._buffer
self._buffer = ""
return StreamingParseResult(normal_text=normal_text)
# Check if tool block hasn't started yet
if not self._in_tool_block:
if self.bot_token in self._buffer:
idx = self._buffer.find(self.bot_token)
normal_text = self._buffer[:idx]
self._buffer = self._buffer[idx + len(self.bot_token) :]
self._in_tool_block = True
return StreamingParseResult(normal_text=normal_text)
else:
# Check if we might have a partial bot_token
partial_len = self._ends_with_partial_token(
self._buffer, self.bot_token
)
if partial_len:
return StreamingParseResult() # Wait for more text
else:
normal_text = self._buffer
self._buffer = ""
return StreamingParseResult(normal_text=normal_text)
# We're inside the tool block
calls: List[ToolCallItem] = []
# Check if tool block is ending
if self.eot_token in self._buffer:
idx = self._buffer.find(self.eot_token)
# If we're in the middle of a tool call, we need to handle it
if self._in_tool_call:
# The buffer before eot_token might contain the end of the current tool call
before_eot = self._buffer[:idx]
if self.tool_call_end in before_eot:
# Parse this final tool call
result = self._parse_partial_tool_call(tools)
calls.extend(result.calls)
else:
# Incomplete tool call - log warning
logger.warning("Tool block ended with incomplete tool call")
remaining = self._buffer[idx + len(self.eot_token) :]
self._buffer = ""
self._tool_block_finished = True
# Reset any partial tool call state
self._reset_streaming_state()
return StreamingParseResult(normal_text=remaining, calls=calls)
# Check if we're in a tool call or need to start one
if not self._in_tool_call:
if self.tool_call_begin in self._buffer:
idx = self._buffer.find(self.tool_call_begin)
# Remove any content before tool call begin (shouldn't happen but be safe)
self._buffer = self._buffer[idx + len(self.tool_call_begin) :]
self._in_tool_call = True
self._function_name_sent = False
self._current_function_name = ""
self._current_parameters = {}
# Fall through to parse the partial tool call
else:
# Wait for tool call to begin
return StreamingParseResult()
# Parse partial tool call
if self._in_tool_call:
return self._parse_partial_tool_call(tools)
return StreamingParseResult()
def _parse_partial_tool_call(self, tools: List[Tool]) -> StreamingParseResult:
"""Parse partial tool call for streaming scenarios."""
calls = []
# Check if we have tool_sep (means we're past the type declaration)
if self.tool_sep not in self._buffer:
return StreamingParseResult(calls=calls) # Wait for more text
type_part, invoke_part = self._buffer.split(self.tool_sep, 1)
if type_part.strip() != "function":
# Invalid tool type, skip this tool call
self._reset_streaming_state()
return StreamingParseResult(calls=calls)
# Try to extract function name if not sent yet
if not self._function_name_sent:
name_match = re.search(r'<steptml:invoke name="([^"]+)">', invoke_part)
if name_match:
func_name = name_match.group(1)
# Validate function name
if func_name in self._tool_indices:
self._current_function_name = func_name
self._function_name_sent = True
# Initialize tool tracking
if self.current_tool_id == -1:
self.current_tool_id = 0
# Ensure tracking arrays are large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
# Store tool call info
self.prev_tool_call_arr[self.current_tool_id] = {
"name": func_name,
"arguments": {},
}
# Send tool name with empty parameters
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=func_name,
parameters="",
)
)
else:
# Invalid function name
logger.warning(f"Invalid function name: {func_name}")
self._reset_streaming_state()
return StreamingParseResult(calls=calls)
else:
# Function name not complete yet
return StreamingParseResult(calls=calls)
# Parse parameters incrementally
if self._function_name_sent:
# Extract all complete parameters
new_params = {}
for param_match in self.param_regex.finditer(invoke_part):
param_name = param_match.group(1)
param_value = param_match.group(2).strip()
# Use schema-aware parsing
arg_type = get_argument_type(
self._current_function_name, param_name, tools
)
if arg_type and arg_type != "string":
parsed_value, _ = parse_arguments(param_value)
new_params[param_name] = parsed_value
else:
new_params[param_name] = param_value
# Check if we have new parameters to stream
if new_params != self._current_parameters:
# Build the JSON content without the closing brace for streaming
if not self._current_parameters:
# First parameters - send opening brace and content
params_content = json.dumps(new_params, ensure_ascii=False)
if len(params_content) > 2: # More than just "{}"
# Send everything except the closing brace
diff = params_content[:-1]
else:
diff = "{"
else:
# Subsequent parameters - calculate the incremental diff
old_json = json.dumps(self._current_parameters, ensure_ascii=False)
new_json = json.dumps(new_params, ensure_ascii=False)
# Remove closing braces for comparison
old_without_brace = old_json[:-1]
new_without_brace = new_json[:-1]
# The new content should extend the old content
if new_without_brace.startswith(old_without_brace):
diff = new_without_brace[len(old_without_brace) :]
else:
# Parameters changed in unexpected way - shouldn't happen in normal streaming
diff = ""
if diff:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
parameters=diff,
)
)
self.streamed_args_for_tool[self.current_tool_id] += diff
# Update current state
self._current_parameters = new_params
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params
# Check if tool call is complete
if self.tool_call_end in self._buffer:
# Send closing brace if we've sent any parameters
if self.streamed_args_for_tool[self.current_tool_id]:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
parameters="}",
)
)
self.streamed_args_for_tool[self.current_tool_id] += "}"
# Find the end position
end_idx = self._buffer.find(self.tool_call_end)
# Remove the processed tool call from buffer
self._buffer = self._buffer[end_idx + len(self.tool_call_end) :]
# Reset state for next tool call
self._reset_streaming_state()
self.current_tool_id += 1
return StreamingParseResult(calls=calls)
def _reset_streaming_state(self):
"""Reset streaming state for the next tool call"""
self._in_tool_call = False
self._function_name_sent = False
self._current_function_name = ""
self._current_parameters = {}
def supports_structural_tag(self) -> bool:
"""Return True if this detector supports structural tag format."""
return False
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError()
def build_ebnf(self, tools: List[Tool]) -> str:
"""
Build EBNF grammar for Step3 tool call format.
"""
# Custom call rule for steptml format
call_rule_fmt = (
'"function" "<|tool_sep|>" "<steptml:invoke name=\\"{name}\\">" '
'{arguments_rule} "</steptml:invoke>"'
)
# Custom key-value rule for steptml parameters
key_value_rule_fmt = (
'"<steptml:parameter name=\\"{key}\\">" {valrule} "</steptml:parameter>"'
)
return EBNFComposer.build_ebnf(
tools,
sequence_start_token=self.bot_token,
sequence_end_token=self.eot_token,
individual_call_start_token=self.tool_call_begin,
individual_call_end_token=self.tool_call_end,
tool_call_separator="",
function_format="xml",
call_rule_fmt=call_rule_fmt,
key_value_rule_fmt=key_value_rule_fmt,
key_value_separator="",
)
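To make the format concrete, here is a self-contained sketch (standard library only, not the detector class itself) that runs the same two regexes on a sample completion in the format described by the docstring above. The tool name and argument values are invented for illustration.

```python
import json
import re

sample = (
    "Sure, let me check.<|tool_calls_begin|>\n"
    '<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="get_weather">\n'
    '<steptml:parameter name="city">Paris</steptml:parameter>\n'
    '<steptml:parameter name="days">3</steptml:parameter>\n'
    "</steptml:invoke><|tool_call_end|>\n"
    "<|tool_calls_end|>Done."
)

invoke_re = re.compile(r'<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>', re.DOTALL)
param_re = re.compile(r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>', re.DOTALL)

name, body = invoke_re.search(sample).groups()
args = {k: v.strip() for k, v in param_re.findall(body)}
# Without a tool schema every value stays a string; with one, non-string params
# (e.g. "days") are coerced via json.loads / ast.literal_eval as in the detector.
print(name, json.dumps(args))  # get_weather {"city": "Paris", "days": "3"}
```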
......@@ -41,6 +41,7 @@ from sglang.srt.configs import (
ExaoneConfig,
KimiVLConfig,
MultiModalityConfig,
Step3VLConfig,
)
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
......@@ -54,6 +55,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
MultiModalityConfig.model_type: MultiModalityConfig,
KimiVLConfig.model_type: KimiVLConfig,
InternVLChatConfig.model_type: InternVLChatConfig,
Step3VLConfig.model_type: Step3VLConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
......
......@@ -165,7 +165,7 @@ def process_content_for_template_format(
new_msg["content"] = processed_content_parts
return new_msg
else: # content_format == "string"
elif content_format == "string":
# String format: flatten to text only (for templates like DeepSeek)
text_parts = []
for chunk in msg_dict["content"]:
......@@ -179,3 +179,6 @@ def process_content_for_template_format(
new_msg["content"] = " ".join(text_parts) if text_parts else ""
new_msg = {k: v for k, v in new_msg.items() if v is not None}
return new_msg
else:
raise ValueError(f"Invalid content format: {content_format}")
......@@ -53,7 +53,7 @@ class TemplateManager:
def __init__(self):
self._chat_template_name: Optional[str] = None
self._completion_template_name: Optional[str] = None
self._jinja_template_content_format: Optional[str] = None
self._jinja_template_content_format: Optional[str] = "openai"
@property
def chat_template_name(self) -> Optional[str]:
......@@ -71,19 +71,50 @@ class TemplateManager:
return self._jinja_template_content_format
def load_chat_template(
self, tokenizer_manager, chat_template_arg: str, model_path: str
self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
) -> None:
"""
Load a chat template from various sources.
Args:
tokenizer_manager: The tokenizer manager instance
chat_template_arg: Template name or file path
chat_template_arg: Template name, file path, or None to auto-detect
model_path: Path to the model
"""
logger.info(f"Loading chat template: {chat_template_arg}")
if chat_template_arg:
self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
else:
# Try HuggingFace template first
hf_template = self._resolve_hf_chat_template(tokenizer_manager)
if hf_template:
self._jinja_template_content_format = (
detect_jinja_template_content_format(hf_template)
)
logger.info(
f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
)
return
# Fallback to SGLang template guessing
self.guess_chat_template_from_model_path(model_path)
# Set default format if no template was found
if self._chat_template_name is None:
self._jinja_template_content_format = "string"
logger.info(
"No chat template found, defaulting to 'string' content format"
)
def _load_explicit_chat_template(
self, tokenizer_manager, chat_template_arg: str
) -> None:
"""Load explicitly specified chat template."""
logger.info(f"Loading chat template from argument: {chat_template_arg}")
if chat_template_exists(chat_template_arg):
self._chat_template_name = chat_template_arg
return
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
......@@ -94,8 +125,6 @@ class TemplateManager:
self._load_jinja_template(tokenizer_manager, chat_template_arg)
else:
self._load_json_chat_template(chat_template_arg)
else:
self._chat_template_name = chat_template_arg
def guess_chat_template_from_model_path(self, model_path: str) -> None:
"""
......@@ -146,10 +175,7 @@ class TemplateManager:
completion_template: Optional completion template name/path
"""
# Load chat template
if chat_template:
self.load_chat_template(tokenizer_manager, chat_template, model_path)
else:
self.guess_chat_template_from_model_path(model_path)
# Load completion template
if completion_template:
......@@ -166,7 +192,7 @@ class TemplateManager:
chat_template
)
logger.info(
f"Detected chat template content format: {self._jinja_template_content_format}"
f"Detected user specified Jinja chat template with content format: {self._jinja_template_content_format}"
)
def _load_json_chat_template(self, template_path: str) -> None:
......@@ -224,3 +250,20 @@ class TemplateManager:
override=True,
)
self._completion_template_name = template["name"]
def _resolve_hf_chat_template(self, tokenizer_manager) -> Optional[str]:
"""
Resolve HuggingFace chat template.
Returns the chat template string if found, None otherwise.
"""
tokenizer = tokenizer_manager.tokenizer
# Try to get AutoTokenizer chat template
try:
return tokenizer.get_chat_template()
except Exception as e:
logger.debug(f"Error getting chat template via get_chat_template(): {e}")
logger.debug("No HuggingFace chat template found")
return None
(One file's diff is collapsed here and not shown.)
......@@ -176,6 +176,8 @@ class BaseMultimodalProcessor(ABC):
"image_grid_hws": Modality.IMAGE,
"aspect_ratio_ids": Modality.IMAGE,
"aspect_ratio_mask": Modality.IMAGE,
"num_patches": Modality.IMAGE,
"patch_pixel_values": Modality.IMAGE,
# Audio-related attributes
"audio_features": Modality.AUDIO,
"audio_feature_lens": Modality.AUDIO,
......
import math
import re
from itertools import product
from typing import List, Literal, Optional, TypedDict, Union
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, TensorType
from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
class GPUToTensor(torch.nn.Module):
def forward(self, raw_image: Union[np.ndarray, Image.Image]) -> torch.Tensor:
if isinstance(raw_image, Image.Image):
return transforms.ToTensor()(raw_image)
if raw_image.ndim == 2:
raw_image = raw_image[:, :, None].repeat(3, -1)
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
image_tensor = torch.from_numpy(raw_image).to(device)
image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
if image_tensor.dtype == torch.uint8:
image_tensor = image_tensor.to(torch.float32).div(255)
return image_tensor
class Step3VisionProcessor:
def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
patch_size = patch_size if patch_size is not None else size
self.transform = transforms.Compose(
[
GPUToTensor(),
transforms.Normalize(mean, std),
transforms.Resize(
(size, size),
interpolation=(
InterpolationMode.BICUBIC
if interpolation_mode == "bicubic"
else InterpolationMode.BILINEAR
),
antialias=True,
),
]
)
self.patch_transform = (
transforms.Compose(
[
GPUToTensor(),
transforms.Normalize(mean, std),
transforms.Resize(
(patch_size, patch_size),
interpolation=(
InterpolationMode.BICUBIC
if interpolation_mode == "bicubic"
else InterpolationMode.BILINEAR
),
antialias=True,
),
]
)
if patch_size is not None
else None
)
def __call__(self, image, is_patch=False):
if is_patch:
return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
else:
return {"pixel_values": self.transform(image).unsqueeze(0)}
class ImagePatcher:
def determine_window_size(self, long: int, short: int) -> int:
if long <= 728:
return short if long / short > 1.5 else 0
return min(short, 504) if long / short > 4 else 504
def slide_window(
self,
width: int,
height: int,
sizes: list[tuple[int, int]],
steps: list[tuple[int, int]],
img_rate_thr: float = 0.6,
) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
assert 1 >= img_rate_thr >= 0, "The `img_rate_thr` should lie in [0, 1]"
windows = []
# Sliding windows.
for size, step in zip(sizes, steps):
size_w, size_h = size
step_w, step_h = step
x_num = 1 if width <= size_w else math.ceil((width - size_w) / step_w + 1)
x_start = [step_w * i for i in range(x_num)]
if len(x_start) > 1 and x_start[-1] + size_w > width:
x_start[-1] = width - size_w
y_num = 1 if height <= size_h else math.ceil((height - size_h) / step_h + 1)
y_start = [step_h * i for i in range(y_num)]
if len(y_start) > 1 and y_start[-1] + size_h > height:
y_start[-1] = height - size_h
start = np.array(list(product(y_start, x_start)), dtype=int)
start[:, [0, 1]] = start[:, [1, 0]]
windows.append(np.concatenate([start, start + size], axis=1))
windows = np.concatenate(windows, axis=0)
return [
(int(box[0]), int(box[1]), int(box[2] - box[0]), int(box[3] - box[1]))
for box in windows
], (x_num, y_num)
def square_pad(self, img: Image.Image) -> Image.Image:
w, h = img.size
if w == h:
return img
size = max(w, h)
padded = Image.new(img.mode, (size, size), 0)
padded.paste(img, (0, 0))
return padded
def get_image_size_for_padding(
self, img_width: int, img_height: int
) -> tuple[int, int]:
ratio = img_width / img_height
if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
new_size = max(img_height, img_width)
return new_size, new_size
return img_width, img_height
def get_image_size_for_preprocess(
self, img_width: int, img_height: int
) -> tuple[int, int]:
if max(img_height, img_width) > 3024:
scale_factor = 3024 / max(img_height, img_width)
img_width = int(img_width * scale_factor)
img_height = int(img_height * scale_factor)
return img_width, img_height
else:
return img_width, img_height
def get_image_size_for_crop(
self, img_width: int, img_height: int, window_size: int
):
w_ratio = img_width / window_size
h_ratio = img_height / window_size
if w_ratio < 1:
width_new = img_width
else:
decimal_w = w_ratio - img_width // window_size
w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
width_new = window_size * w_ratio
if h_ratio < 1:
height_new = img_height
else:
decimal_h = h_ratio - img_height // window_size
h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
height_new = window_size * h_ratio
return int(width_new), int(height_new)
def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
target = img.crop((j, i, j + tw, i + th))
return target
def get_num_patches(self, img_width: int, img_height: int) -> tuple[int, int]:
img_width, img_height = self.get_image_size_for_padding(img_width, img_height)
img_width, img_height = self.get_image_size_for_preprocess(
img_width, img_height
)
window_size = self.determine_window_size(
max(img_height, img_width), min(img_height, img_width)
)
if window_size == 0:
return 0, 0
else:
img_width, img_height = self.get_image_size_for_crop(
img_width, img_height, window_size
)
center_list, (x_num, y_num) = self.slide_window(
img_width,
img_height,
[(window_size, window_size)],
[(window_size, window_size)],
)
full_rows = (len(center_list) - 1) // x_num + 1
if len(center_list) > 0 and len(center_list) % x_num == 0:
full_rows -= 1
return len(center_list), full_rows
def __call__(
self, img: Image.Image
) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
img_width, img_height = img.size
new_img_width, new_img_height = self.get_image_size_for_padding(
img_width, img_height
)
if new_img_width != img_width or new_img_height != img_height:
img = self.square_pad(img)
img_width, img_height = img.size
new_img_width, new_img_height = self.get_image_size_for_preprocess(
img_width, img_height
)
img = img.resize((new_img_width, new_img_height), Image.Resampling.BILINEAR)
window_size = self.determine_window_size(
max(new_img_height, new_img_width), min(new_img_height, new_img_width)
)
if window_size == 0:
return img, [], None
else:
new_img_width, new_img_height = self.get_image_size_for_crop(
new_img_width, new_img_height, window_size
)
if (new_img_width, new_img_height) != (img_width, img_height):
img_for_crop = img.resize(
(new_img_width, new_img_height), Image.Resampling.BILINEAR
)
else:
img_for_crop = img
patches = []
newlines = []
center_list, (x_num, y_num) = self.slide_window(
new_img_width,
new_img_height,
[(window_size, window_size)],
[(window_size, window_size)],
)
for patch_id, center_lf_point in enumerate(center_list):
x, y, patch_w, patch_h = center_lf_point
big_patch = self.patch_crop(img_for_crop, y, x, patch_h, patch_w)
patches.append(big_patch)
if (patch_id + 1) % x_num == 0:
newlines.append(patch_id)
if newlines and newlines[-1] == len(patches) - 1:
newlines.pop()
return (
img,
patches,
(
[i in newlines for i in range(len(patches))]
if len(patches) > 0
else None
),
)
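A small sketch of how the patcher above behaves on a wide image, useful for sanity-checking the window and crop math. The input size is arbitrary, and the class is assumed to be importable or in scope (the module path is not shown in this diff); the printed values follow from the code above.

```python
# Probe the crop planning without decoding a real image.
# Assumes ImagePatcher from the class definition above is in scope.
patcher = ImagePatcher()

# A 2000x1000 landscape: long/short = 2 <= 4, so the 504 px sliding window is used,
# the image is snapped to window multiples (2016x1008) and tiled into crops.
num_crops, num_newline_rows = patcher.get_num_patches(2000, 1000)
print(num_crops, num_newline_rows)  # a 4x2 grid of crops -> (8, 1) by the logic above
```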
class Step3VLProcessor:
def __init__(
self,
config,
tokenizer,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
self.image_size = 728
self.patch_size = 504
self.image_preprocessor = Step3VisionProcessor(
self.image_size, "bilinear", self.patch_size
)
self.num_image_feature_size = 169
self.num_patch_feature_size = 81
self.image_token = "<im_patch>"
self.image_feature_placeholder = self.image_token * self.num_image_feature_size
self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size
self.patcher = ImagePatcher()
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[self.image_token]
def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height)
return (
num_patches * (self.num_patch_feature_size + 2)
+ self.num_image_feature_size
+ 2
+ num_newlines
)
def _split_images(self, images: list[Image.Image]) -> list[ImageWithPatches]:
result = []
for img in images:
result.append(self.patcher(img))
return result
def _convert_images_to_pixel_values(
self,
images: list[Image.Image],
is_patch: bool = False,
) -> list[torch.Tensor]:
return [
self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
for img in images
]
def _get_patch_repl(
self,
num_patches: int,
patch_newline_mask: list[bool] | None,
) -> tuple[str, list[int]]:
text = ""
token_ids = []
for i in range(num_patches):
assert len(patch_newline_mask) == num_patches
text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
token_ids.extend(
[self.tokenizer.convert_tokens_to_ids("<patch_start>")]
+ [self.image_token_id] * self.num_patch_feature_size
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")]
)
if patch_newline_mask and patch_newline_mask[i]:
text += "<patch_newline>"
token_ids.append(
self.tokenizer.convert_tokens_to_ids("<patch_newline>")
)
return text, token_ids
def _get_image_repl(
self,
num_images: int,
) -> tuple[str, list[int]]:
text = f"<im_start>{self.image_feature_placeholder}<im_end>"
token_ids = (
[self.tokenizer.convert_tokens_to_ids("<im_start>")]
+ [self.image_token_id] * self.num_image_feature_size
+ [self.tokenizer.convert_tokens_to_ids("<im_end>")]
)
return text * num_images, token_ids * num_images
def _get_image_repl_features(
self,
num_images: int,
num_patches: int,
patch_new_line_idx: Optional[list[bool]],
) -> tuple[str, list[int]]:
if num_patches > 0:
patch_repl, patch_repl_ids = self._get_patch_repl(
num_patches, patch_new_line_idx
)
else:
patch_repl = ""
patch_repl_ids = []
image_repl, image_repl_ids = self._get_image_repl(num_images)
return patch_repl + image_repl, patch_repl_ids + image_repl_ids
def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str:
parts = text.split(placeholder)
if len(parts) - 1 != len(repls):
raise ValueError(
"The number of placeholders does not match the number of replacements." # noqa: E501
)
result = [parts[0]]
for i, repl in enumerate(repls):
result.append(repl)
result.append(parts[i + 1])
return "".join(result)
def __call__(
self,
text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
*args,
**kwargs,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if len(images) == 0:
image_inputs = {}
text_inputs = self.tokenizer(text)
else:
splitted_images_data = self._split_images(images)
pixel_values_lst = []
patch_pixel_values_lst = []
patch_newline_mask_lst = []
image_repl_str_lst = []
image_repl_ids_lst = []
num_patches = []
for (
raw_img,
img_patches,
patch_newline_mask,
) in splitted_images_data: # noqa: E501
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
if len(img_patches) > 0:
patch_pixel_values_lst.extend(
self._convert_images_to_pixel_values(img_patches, is_patch=True)
)
num_patches.append(len(img_patches))
image_repl_str, image_repl_ids = self._get_image_repl_features(
1, len(img_patches), patch_newline_mask
)
image_repl_str_lst.append(image_repl_str)
image_repl_ids_lst.extend(image_repl_ids)
if patch_newline_mask is not None:
patch_newline_mask_lst.extend(patch_newline_mask)
image_inputs = {
"pixel_values": torch.cat(pixel_values_lst),
"num_patches": num_patches,
}
if patch_pixel_values_lst:
image_inputs["patch_pixel_values"] = torch.cat(patch_pixel_values_lst)
if patch_newline_mask_lst:
image_inputs["patch_newline_mask"] = torch.tensor(
patch_newline_mask_lst, dtype=torch.bool
)
text = [
self.replace_placeholder(t, self.image_token, image_repl_str_lst)
for t in text
]
text_inputs = self.tokenizer(text)
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
################################################
class Step3VLImageProcessor(SGLangBaseProcessor):
models = [Step3VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
# TODO: check whether _processor is a tokenizer or a processor.
processor = Step3VLProcessor(hf_config, _processor)
super().__init__(hf_config, server_args, processor, *args, **kwargs)
self.IM_TOKEN_ID = 128001
self.mm_tokens = MultimodalSpecialTokens(
image_token="<im_patch>",
image_token_id=128001,
image_token_regex=re.compile(r"(?:<im_patch>)"),
).build(_processor)
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
def preprocess(self, image):
return {"pixel_values": self.transform(image).unsqueeze(0)}
def __call__(self, image):
return self.preprocess(image)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text: str | List[int],
request_obj,
*args,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
video_data=request_obj.video_data,
multimodal_tokens=self.mm_tokens,
)
mm_items, input_ids, ret = self.process_and_combine_mm_data(
base_output, self.mm_tokens
)
return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
"im_token_id": self.mm_tokens.image_token_id,
}
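As a quick sanity check on the token accounting above: `get_num_image_tokens` charges `num_patch_feature_size + 2` tokens per crop (81 patch features plus `<patch_start>`/`<patch_end>`), one `<patch_newline>` per completed row, and `num_image_feature_size + 2` for the global image (169 features plus `<im_start>`/`<im_end>`). A hedged example with a hypothetical patch layout:

```python
# Hypothetical layout: 8 crops in two rows of four, of which one row contributes
# a <patch_newline> (mirroring get_num_patches' full-row counting above).
num_patches, num_newlines = 8, 1
num_patch_feature_size, num_image_feature_size = 81, 169

total = num_patches * (num_patch_feature_size + 2) + num_image_feature_size + 2 + num_newlines
print(total)  # 8 * 83 + 169 + 2 + 1 = 836 image/special tokens for this image
```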
......@@ -105,7 +105,7 @@ class BaseReasoningFormatDetector:
# If we're not in a reasoning block return as normal text
if not self._in_reasoning:
self._buffer = ""
return StreamingParseResult(normal_text=new_text)
return StreamingParseResult(normal_text=current_text)
return StreamingParseResult()
......@@ -233,6 +233,7 @@ class ReasoningParser:
"qwen3-thinking": Qwen3ThinkingDetector,
"glm45": Qwen3Detector,
"kimi": KimiDetector,
"step3": DeepSeekR1Detector,
}
def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
......
......@@ -1117,9 +1117,10 @@ class ServerArgs:
"kimi_k2",
"qwen3_coder",
"glm45",
"step3",
],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
)
# Data parallelism
......
......@@ -493,5 +493,117 @@ class TestIntegrationScenarios(CustomTestCase):
self.assertIn("final answer", all_normal)
class TestBufferLossBugFix(CustomTestCase):
"""Test cases for the buffer loss bug fix in parse_streaming_increment."""
def test_partial_end_tag_buffer_loss_bug(self):
"""
Test the bug where partial end tag fragments are lost when followed by normal text.
Bug scenario:
1. _in_reasoning is False
2. new_text is "</" (part of closing thinking tag)
3. Fragment is stored in buffer and empty string is returned
4. Next step: new_text is "answer", _in_reasoning still False
5. Buffer is cleared and "answer" is returned directly
6. The "</" from previous step is lost
This test verifies the fix where line 108 was changed from:
return StreamingParseResult(normal_text=new_text)
to:
return StreamingParseResult(normal_text=current_text)
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Step 1: Send partial end tag when not in reasoning mode
# This should be buffered since it could be start of "</think>"
result1 = detector.parse_streaming_increment("</")
self.assertEqual(result1.normal_text, "")
self.assertEqual(result1.reasoning_text, "")
# Step 2: Send normal text that doesn't complete the end tag
# Before fix: would return only "answer", losing the "</"
# After fix: should return the complete buffered content "</answer"
result2 = detector.parse_streaming_increment("answer")
self.assertEqual(result2.normal_text, "</answer")
self.assertEqual(result2.reasoning_text, "")
def test_partial_start_tag_buffer_preservation(self):
"""
Test that partial start tag fragments are properly preserved.
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Send partial start tag
result1 = detector.parse_streaming_increment("<th")
self.assertEqual(result1.normal_text, "")
self.assertEqual(result1.reasoning_text, "")
# Complete with non-matching text
result2 = detector.parse_streaming_increment("is is text")
self.assertEqual(result2.normal_text, "<this is text")
self.assertEqual(result2.reasoning_text, "")
def test_partial_end_tag_in_reasoning_mode(self):
"""
Test partial end tag handling when already in reasoning mode.
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Enter reasoning mode
detector.parse_streaming_increment("<think>")
detector.parse_streaming_increment("some reasoning")
# Send partial end tag
result1 = detector.parse_streaming_increment("</")
self.assertEqual(result1.normal_text, "")
self.assertEqual(result1.reasoning_text, "")
# Complete the end tag with normal text
result2 = detector.parse_streaming_increment("think>normal text")
self.assertEqual(result2.normal_text, "normal text")
# The reasoning text should be empty since buffer was cleared when end tag was processed
self.assertEqual(result2.reasoning_text, "")
def test_multiple_partial_fragments(self):
"""
Test handling of multiple partial fragments that don't match any tokens.
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Send multiple partial fragments
result1 = detector.parse_streaming_increment("<")
self.assertEqual(result1.normal_text, "")
self.assertEqual(result1.reasoning_text, "")
result2 = detector.parse_streaming_increment("/")
self.assertEqual(result2.normal_text, "")
self.assertEqual(result2.reasoning_text, "")
result3 = detector.parse_streaming_increment("random>")
self.assertEqual(result3.normal_text, "</random>")
self.assertEqual(result3.reasoning_text, "")
def test_edge_case_exact_token_match(self):
"""
Test edge case where buffer content exactly matches a token.
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Build up the exact start token character by character
detector.parse_streaming_increment("<")
detector.parse_streaming_increment("t")
detector.parse_streaming_increment("h")
detector.parse_streaming_increment("i")
detector.parse_streaming_increment("n")
result = detector.parse_streaming_increment("k>")
# Should enter reasoning mode
self.assertEqual(result.normal_text, "")
self.assertEqual(result.reasoning_text, "")
self.assertTrue(detector._in_reasoning)
self.assertTrue(detector.stripped_think_start)
if __name__ == "__main__":
unittest.main()