Unverified Commit 51c38163 authored by Chang Su, committed by GitHub

model: support Step3V (#8583)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: nnnobody-code <nnnobody@foxmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Qiaolin-Yu <qy254@cornell.edu>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent 09f1a247
......@@ -148,7 +148,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--file-storage-path` | The path of the file storage in backend. | sglang_storage |
| `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False |
| `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None |
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None |
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'. | None |
## Data parallelism
......
......@@ -5,6 +5,11 @@ from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig,
)
__all__ = [
"ExaoneConfig",
......@@ -14,4 +19,7 @@ __all__ = [
"MultiModalityConfig",
"KimiVLConfig",
"MoonViTConfig",
"Step3VLConfig",
"Step3TextConfig",
"Step3VisionEncoderConfig",
]
......@@ -335,6 +335,8 @@ class ModelConfig:
"num_key_value_heads",
# For ChatGLM:
"multi_query_group_num",
# For Step3
"num_attention_groups",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_text_config, attr, None)
......@@ -644,6 +646,7 @@ multimodal_model_archs = [
"InternS1ForConditionalGeneration",
"Phi4MMForCausalLM",
"VILAForConditionalGeneration",
"Step3VLForConditionalGeneration",
]
......
from typing import Any, Optional, Union
from transformers.configuration_utils import PretrainedConfig
class Step3VisionEncoderConfig(PretrainedConfig):
model_type = "step3_vision_encoder"
def __init__(
self,
hidden_size=1792,
intermediate_size=3072,
output_hidden_size=4096,
num_hidden_layers=63,
num_attention_heads=16,
num_channels=3,
image_size=728,
patch_size=14,
hidden_act="quick_gelu",
layer_norm_eps=1e-5,
**kwargs,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.output_hidden_size = output_hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
super().__init__(**kwargs)
class Step3TextConfig(PretrainedConfig):
model_type = "step3_text"
architectures = ["Step3TextForCausalLM"]
def __init__(
self,
hidden_size: int = 7168,
intermediate_size: int = 18432,
num_attention_heads: int = 64,
num_attention_groups: int = 1,
num_hidden_layers: int = 61,
max_seq_len: int = 65536,
vocab_size: int = 128815,
rms_norm_eps: float = 1e-5,
moe_intermediate_size: int = 5120,
moe_num_experts: int = 48,
moe_top_k: int = 3,
rope_theta: float = 500000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embedding: int = 65536,
share_expert_dim: int = 5120,
share_q_dim: int = 2048,
head_dim: int = 256,
norm_expert_weight: bool = False,
moe_layers_enum: tuple[int, ...] = tuple(range(4, 60)),  # layers 4-59 are MoE layers
**kwargs,
) -> None:
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_attention_groups = num_attention_groups
self.num_hidden_layers = num_hidden_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.rms_norm_eps = rms_norm_eps
self.moe_intermediate_size = moe_intermediate_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.max_position_embedding = max_position_embedding
self.share_expert_dim = share_expert_dim
self.share_q_dim = share_q_dim
self.head_dim = head_dim
self.norm_expert_weight = norm_expert_weight
self.moe_layers_enum = moe_layers_enum
super().__init__(**kwargs)
class Step3VLConfig(PretrainedConfig):
model_type = "step3_vl"
def __init__(
self,
vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
text_config: Optional[Union[dict, Step3TextConfig]] = None,
understand_projector_stride: int = 1,
projector_bias: bool = True,
image_token_id: int = 128001,
**kwargs,
) -> None:
if vision_config is None:
vision_config = Step3VisionEncoderConfig()
elif isinstance(vision_config, dict):
vision_config = Step3VisionEncoderConfig(**vision_config)
self.vision_config = vision_config
if text_config is None:
text_config = Step3TextConfig()
elif isinstance(text_config, dict):
text_config = Step3TextConfig(**text_config)
self.text_config = text_config
self.understand_projector_stride = understand_projector_stride
self.projector_bias = projector_bias
self.hidden_size = text_config.hidden_size
self.image_token_id = image_token_id
super().__init__(**kwargs)
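
For reference, a minimal sketch of how these config classes compose; the sub-config values below are illustrative, not checkpoint hyperparameters:

```python
from sglang.srt.configs.step3_vl import Step3TextConfig, Step3VLConfig

# Dict sub-configs are coerced into the typed config classes by Step3VLConfig.__init__;
# omitted fields fall back to the defaults defined above.
config = Step3VLConfig(
    vision_config={"hidden_size": 1792, "num_hidden_layers": 63},
    text_config={"hidden_size": 7168, "moe_num_experts": 48, "moe_top_k": 3},
    image_token_id=128001,
)

assert isinstance(config.text_config, Step3TextConfig)
assert config.hidden_size == config.text_config.hidden_size  # mirrored for convenience
```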
......@@ -994,6 +994,23 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="step3-vl",
system_message="<|begin▁of▁sentence|>You are a helpful assistant",
system_template="{system_message}\n",
roles=(
"<|BOT|>user\n",
"<|BOT|>assistant\n<think>\n",
),
sep="<|EOT|>",
sep_style=SeparatorStyle.NO_COLON_SINGLE,
stop_str="<|EOT|>",
image_token="<im_patch>",
# add_bos=True,
)
)
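
As a rough illustration of the template above (a sketch, not output from the actual Conversation renderer), a single-turn prompt with NO_COLON_SINGLE separation would look roughly like this:

```python
# Approximate rendering for one user turn containing an image:
# system_template appends "\n", each role string carries its own "\n", sep is "<|EOT|>".
prompt = (
    "<|begin▁of▁sentence|>You are a helpful assistant\n"
    "<|BOT|>user\n"
    "Describe this image. <im_patch>"
    "<|EOT|>"
    "<|BOT|>assistant\n<think>\n"
)
```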
@register_conv_template_matching_function
def match_internvl(model_path: str):
......@@ -1103,3 +1120,9 @@ def match_vila(model_path: str):
def match_mimo_vl(model_path: str):
if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
return "mimo-vl"
# @register_conv_template_matching_function
# def match_step3(model_path: str):
# if re.search(r"step3", model_path, re.IGNORECASE):
# return "step3-vl"
......@@ -17,6 +17,7 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.function_call.step3_detector import Step3Detector
logger = logging.getLogger(__name__)
......@@ -39,6 +40,7 @@ class FunctionCallParser:
"kimi_k2": KimiK2Detector,
"qwen3_coder": Qwen3CoderDetector,
"glm45": Glm4MoeDetector,
"step3": Step3Detector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
......
import ast
import json
import logging
import re
from typing import Any, Dict, List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
logger = logging.getLogger(__name__)
def get_argument_type(func_name: str, arg_key: str, defined_tools: List[Tool]) -> Optional[str]:
"""Get the expected type of a function argument from the tool schema, or None if unknown."""
name2tool = {tool.function.name: tool for tool in defined_tools}
if func_name not in name2tool:
return None
tool = name2tool[func_name]
parameters = tool.function.parameters or {}
properties = parameters.get("properties", {})
if arg_key not in properties:
return None
return properties[arg_key].get("type", None)
def parse_arguments(value: str) -> tuple[Any, bool]:
"""Parse a string value to appropriate type. Returns (parsed_value, success)."""
try:
try:
parsed_value = json.loads(value)
except Exception:
parsed_value = ast.literal_eval(value)
return parsed_value, True
except Exception:
return value, False
class Step3Detector(BaseFormatDetector):
"""
Detector for Step3 model function call format.
The Step3 format uses special Unicode tokens to delimit function calls
with steptml XML format for invocations.
Format Structure:
```
<|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="function_name">
<steptml:parameter name="param1">value1</steptml:parameter>
<steptml:parameter name="param2">value2</steptml:parameter>
</steptml:invoke><|tool_call_end|>
<|tool_calls_end|>
```
"""
def __init__(self):
super().__init__()
self.bot_token = "<|tool_calls_begin|>"
self.eot_token = "<|tool_calls_end|>"
self.tool_call_begin = "<|tool_call_begin|>"
self.tool_call_end = "<|tool_call_end|>"
self.tool_sep = "<|tool_sep|>"
# Regex for parsing steptml invocations
self.invoke_regex = re.compile(
r'<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>', re.DOTALL
)
self.param_regex = re.compile(
r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>', re.DOTALL
)
# Streaming state variables
self._in_tool_block: bool = False
self._tool_block_finished: bool = False
self._current_function_name: str = ""
self._current_parameters: Dict[str, Any] = {}
self._in_tool_call: bool = False
self._function_name_sent: bool = False
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Step3 format tool call."""
return self.bot_token in text
def _parse_steptml_invoke(
self, text: str, tools: Optional[List[Tool]] = None
) -> tuple[Optional[str], dict]:
"""Parse steptml invoke format to extract function name and parameters."""
invoke_match = self.invoke_regex.search(text)
if not invoke_match:
return None, {}
func_name = invoke_match.group(1)
params_text = invoke_match.group(2)
params = {}
for param_match in self.param_regex.finditer(params_text):
param_name = param_match.group(1)
param_value = param_match.group(2).strip()
# If tools provided, use schema-aware parsing
if tools:
arg_type = get_argument_type(func_name, param_name, tools)
if arg_type and arg_type != "string":
parsed_value, _ = parse_arguments(param_value)
params[param_name] = parsed_value
else:
params[param_name] = param_value
else:
# Fallback to generic parsing if no tools provided
parsed_value, _ = parse_arguments(param_value)
params[param_name] = parsed_value
return func_name, params
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
"""
if self.bot_token not in text:
return StreamingParseResult(normal_text=text, calls=[])
try:
pre_text, rest = text.split(self.bot_token, 1)
# If no end token, return everything as normal text
if self.eot_token not in rest:
return StreamingParseResult(normal_text=text, calls=[])
tool_section, post_text = rest.split(self.eot_token, 1)
# Find all individual tool calls using regex
calls = []
tool_call_pattern = (
f"{re.escape(self.tool_call_begin)}(.*?){re.escape(self.tool_call_end)}"
)
for match in re.finditer(tool_call_pattern, tool_section, re.DOTALL):
call_content = match.group(1)
# Check if it's a function call
if self.tool_sep not in call_content:
continue
type_part, invoke_part = call_content.split(self.tool_sep, 1)
if type_part.strip() != "function":
continue
func_name, params = self._parse_steptml_invoke(invoke_part, tools)
if func_name:
# Use parse_base_json to create the ToolCallItem
action = {"name": func_name, "arguments": params}
calls.extend(self.parse_base_json(action, tools))
# Combine pre and post text
normal_text = pre_text + post_text
return StreamingParseResult(normal_text=normal_text, calls=calls)
except Exception as e:
logger.error(f"Error in detect_and_parse: {e}")
# Return the original text if parsing fails
return StreamingParseResult(normal_text=text)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing for Step3 format.
"""
self._buffer += new_text
# Build tool indices for validation
if not hasattr(self, "_tool_indices"):
self._tool_indices = self._get_tool_indices(tools)
# If we've finished the tool block, everything is normal text
if self._tool_block_finished:
normal_text = self._buffer
self._buffer = ""
return StreamingParseResult(normal_text=normal_text)
# Check if tool block hasn't started yet
if not self._in_tool_block:
if self.bot_token in self._buffer:
idx = self._buffer.find(self.bot_token)
normal_text = self._buffer[:idx]
self._buffer = self._buffer[idx + len(self.bot_token) :]
self._in_tool_block = True
return StreamingParseResult(normal_text=normal_text)
else:
# Check if we might have a partial bot_token
partial_len = self._ends_with_partial_token(
self._buffer, self.bot_token
)
if partial_len:
return StreamingParseResult() # Wait for more text
else:
normal_text = self._buffer
self._buffer = ""
return StreamingParseResult(normal_text=normal_text)
# We're inside the tool block
calls: List[ToolCallItem] = []
# Check if tool block is ending
if self.eot_token in self._buffer:
idx = self._buffer.find(self.eot_token)
# If we're in the middle of a tool call, we need to handle it
if self._in_tool_call:
# The buffer before eot_token might contain the end of the current tool call
before_eot = self._buffer[:idx]
if self.tool_call_end in before_eot:
# Parse this final tool call
result = self._parse_partial_tool_call(tools)
calls.extend(result.calls)
else:
# Incomplete tool call - log warning
logger.warning("Tool block ended with incomplete tool call")
remaining = self._buffer[idx + len(self.eot_token) :]
self._buffer = ""
self._tool_block_finished = True
# Reset any partial tool call state
self._reset_streaming_state()
return StreamingParseResult(normal_text=remaining, calls=calls)
# Check if we're in a tool call or need to start one
if not self._in_tool_call:
if self.tool_call_begin in self._buffer:
idx = self._buffer.find(self.tool_call_begin)
# Remove any content before tool call begin (shouldn't happen but be safe)
self._buffer = self._buffer[idx + len(self.tool_call_begin) :]
self._in_tool_call = True
self._function_name_sent = False
self._current_function_name = ""
self._current_parameters = {}
# Fall through to parse the partial tool call
else:
# Wait for tool call to begin
return StreamingParseResult()
# Parse partial tool call
if self._in_tool_call:
return self._parse_partial_tool_call(tools)
return StreamingParseResult()
def _parse_partial_tool_call(self, tools: List[Tool]) -> StreamingParseResult:
"""Parse partial tool call for streaming scenarios."""
calls = []
# Check if we have tool_sep (means we're past the type declaration)
if self.tool_sep not in self._buffer:
return StreamingParseResult(calls=calls) # Wait for more text
type_part, invoke_part = self._buffer.split(self.tool_sep, 1)
if type_part.strip() != "function":
# Invalid tool type, skip this tool call
self._reset_streaming_state()
return StreamingParseResult(calls=calls)
# Try to extract function name if not sent yet
if not self._function_name_sent:
name_match = re.search(r'<steptml:invoke name="([^"]+)">', invoke_part)
if name_match:
func_name = name_match.group(1)
# Validate function name
if func_name in self._tool_indices:
self._current_function_name = func_name
self._function_name_sent = True
# Initialize tool tracking
if self.current_tool_id == -1:
self.current_tool_id = 0
# Ensure tracking arrays are large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
# Store tool call info
self.prev_tool_call_arr[self.current_tool_id] = {
"name": func_name,
"arguments": {},
}
# Send tool name with empty parameters
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=func_name,
parameters="",
)
)
else:
# Invalid function name
logger.warning(f"Invalid function name: {func_name}")
self._reset_streaming_state()
return StreamingParseResult(calls=calls)
else:
# Function name not complete yet
return StreamingParseResult(calls=calls)
# Parse parameters incrementally
if self._function_name_sent:
# Extract all complete parameters
new_params = {}
for param_match in self.param_regex.finditer(invoke_part):
param_name = param_match.group(1)
param_value = param_match.group(2).strip()
# Use schema-aware parsing
arg_type = get_argument_type(
self._current_function_name, param_name, tools
)
if arg_type and arg_type != "string":
parsed_value, _ = parse_arguments(param_value)
new_params[param_name] = parsed_value
else:
new_params[param_name] = param_value
# Check if we have new parameters to stream
if new_params != self._current_parameters:
# Build the JSON content without the closing brace for streaming
if not self._current_parameters:
# First parameters - send opening brace and content
params_content = json.dumps(new_params, ensure_ascii=False)
if len(params_content) > 2: # More than just "{}"
# Send everything except the closing brace
diff = params_content[:-1]
else:
diff = "{"
else:
# Subsequent parameters - calculate the incremental diff
old_json = json.dumps(self._current_parameters, ensure_ascii=False)
new_json = json.dumps(new_params, ensure_ascii=False)
# Remove closing braces for comparison
old_without_brace = old_json[:-1]
new_without_brace = new_json[:-1]
# The new content should extend the old content
if new_without_brace.startswith(old_without_brace):
diff = new_without_brace[len(old_without_brace) :]
else:
# Parameters changed in unexpected way - shouldn't happen in normal streaming
diff = ""
if diff:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
parameters=diff,
)
)
self.streamed_args_for_tool[self.current_tool_id] += diff
# Update current state
self._current_parameters = new_params
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params
# Check if tool call is complete
if self.tool_call_end in self._buffer:
# Send closing brace if we've sent any parameters
if self.streamed_args_for_tool[self.current_tool_id]:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
parameters="}",
)
)
self.streamed_args_for_tool[self.current_tool_id] += "}"
# Find the end position
end_idx = self._buffer.find(self.tool_call_end)
# Remove the processed tool call from buffer
self._buffer = self._buffer[end_idx + len(self.tool_call_end) :]
# Reset state for next tool call
self._reset_streaming_state()
self.current_tool_id += 1
return StreamingParseResult(calls=calls)
def _reset_streaming_state(self):
"""Reset streaming state for the next tool call"""
self._in_tool_call = False
self._function_name_sent = False
self._current_function_name = ""
self._current_parameters = {}
def supports_structural_tag(self) -> bool:
"""Return True if this detector supports structural tag format."""
return False
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError()
def build_ebnf(self, tools: List[Tool]) -> str:
"""
Build EBNF grammar for Step3 tool call format.
"""
# Custom call rule for steptml format
call_rule_fmt = (
'"function" "<|tool_sep|>" "<steptml:invoke name=\\"{name}\\">" '
'{arguments_rule} "</steptml:invoke>"'
)
# Custom key-value rule for steptml parameters
key_value_rule_fmt = (
'"<steptml:parameter name=\\"{key}\\">" {valrule} "</steptml:parameter>"'
)
return EBNFComposer.build_ebnf(
tools,
sequence_start_token=self.bot_token,
sequence_end_token=self.eot_token,
individual_call_start_token=self.tool_call_begin,
individual_call_end_token=self.tool_call_end,
tool_call_separator="",
function_format="xml",
call_rule_fmt=call_rule_fmt,
key_value_rule_fmt=key_value_rule_fmt,
key_value_separator="",
)
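
A minimal sketch of the detector on the format described in the class docstring, using only the methods defined above; the sample text and tool name are made up for illustration:

```python
from sglang.srt.function_call.step3_detector import Step3Detector

detector = Step3Detector()

sample = (
    "Let me check the weather.<|tool_calls_begin|>"
    '<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="get_weather">'
    '<steptml:parameter name="city">Paris</steptml:parameter>'
    '<steptml:parameter name="days">3</steptml:parameter>'
    "</steptml:invoke><|tool_call_end|><|tool_calls_end|>"
)

assert detector.has_tool_call(sample)

# Without a tool schema, values fall back to generic json/literal parsing,
# so "3" becomes the integer 3 while "Paris" stays a string.
name, params = detector._parse_steptml_invoke(sample)
print(name, params)  # get_weather {'city': 'Paris', 'days': 3}
```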
......@@ -41,6 +41,7 @@ from sglang.srt.configs import (
ExaoneConfig,
KimiVLConfig,
MultiModalityConfig,
Step3VLConfig,
)
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
......@@ -54,6 +55,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
MultiModalityConfig.model_type: MultiModalityConfig,
KimiVLConfig.model_type: KimiVLConfig,
InternVLChatConfig.model_type: InternVLChatConfig,
Step3VLConfig.model_type: Step3VLConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
......
......@@ -165,7 +165,7 @@ def process_content_for_template_format(
new_msg["content"] = processed_content_parts
return new_msg
else: # content_format == "string"
elif content_format == "string":
# String format: flatten to text only (for templates like DeepSeek)
text_parts = []
for chunk in msg_dict["content"]:
......@@ -179,3 +179,6 @@ def process_content_for_template_format(
new_msg["content"] = " ".join(text_parts) if text_parts else ""
new_msg = {k: v for k, v in new_msg.items() if v is not None}
return new_msg
else:
raise ValueError(f"Invalid content format: {content_format}")
......@@ -53,7 +53,7 @@ class TemplateManager:
def __init__(self):
self._chat_template_name: Optional[str] = None
self._completion_template_name: Optional[str] = None
self._jinja_template_content_format: Optional[str] = None
self._jinja_template_content_format: Optional[str] = "openai"
@property
def chat_template_name(self) -> Optional[str]:
......@@ -71,31 +71,60 @@ class TemplateManager:
return self._jinja_template_content_format
def load_chat_template(
self, tokenizer_manager, chat_template_arg: str, model_path: str
self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
) -> None:
"""
Load a chat template from various sources.
Args:
tokenizer_manager: The tokenizer manager instance
chat_template_arg: Template name or file path
chat_template_arg: Template name, file path, or None to auto-detect
model_path: Path to the model
"""
logger.info(f"Loading chat template: {chat_template_arg}")
if chat_template_arg:
self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
else:
# Try HuggingFace template first
hf_template = self._resolve_hf_chat_template(tokenizer_manager)
if hf_template:
self._jinja_template_content_format = (
detect_jinja_template_content_format(hf_template)
)
logger.info(
f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
)
return
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
"or a valid chat template file path."
# Fallback to SGLang template guessing
self.guess_chat_template_from_model_path(model_path)
# Set default format if no template was found
if self._chat_template_name is None:
self._jinja_template_content_format = "string"
logger.info(
"No chat template found, defaulting to 'string' content format"
)
if chat_template_arg.endswith(".jinja"):
self._load_jinja_template(tokenizer_manager, chat_template_arg)
else:
self._load_json_chat_template(chat_template_arg)
else:
def _load_explicit_chat_template(
self, tokenizer_manager, chat_template_arg: str
) -> None:
"""Load explicitly specified chat template."""
logger.info(f"Loading chat template from argument: {chat_template_arg}")
if chat_template_exists(chat_template_arg):
self._chat_template_name = chat_template_arg
return
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
"or a valid chat template file path."
)
if chat_template_arg.endswith(".jinja"):
self._load_jinja_template(tokenizer_manager, chat_template_arg)
else:
self._load_json_chat_template(chat_template_arg)
def guess_chat_template_from_model_path(self, model_path: str) -> None:
"""
......@@ -146,10 +175,7 @@ class TemplateManager:
completion_template: Optional completion template name/path
"""
# Load chat template
if chat_template:
self.load_chat_template(tokenizer_manager, chat_template, model_path)
else:
self.guess_chat_template_from_model_path(model_path)
self.load_chat_template(tokenizer_manager, chat_template, model_path)
# Load completion template
if completion_template:
......@@ -166,7 +192,7 @@ class TemplateManager:
chat_template
)
logger.info(
f"Detected chat template content format: {self._jinja_template_content_format}"
f"Detected user specified Jinja chat template with content format: {self._jinja_template_content_format}"
)
def _load_json_chat_template(self, template_path: str) -> None:
......@@ -224,3 +250,20 @@ class TemplateManager:
override=True,
)
self._completion_template_name = template["name"]
def _resolve_hf_chat_template(self, tokenizer_manager) -> Optional[str]:
"""
Resolve HuggingFace chat template.
Returns the chat template string if found, None otherwise.
"""
tokenizer = tokenizer_manager.tokenizer
# Try to get AutoTokenizer chat template
try:
return tokenizer.get_chat_template()
except Exception as e:
logger.debug(f"Error getting chat template via get_chat_template(): {e}")
logger.debug("No HuggingFace chat template found")
return None
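
To summarize the fallback order introduced above, a hypothetical standalone sketch (not the actual TemplateManager API):

```python
def resolve_chat_template(chat_template_arg, hf_template, guessed_template):
    """Hypothetical summary of the resolution order used by load_chat_template above."""
    if chat_template_arg:
        # 1. An explicit --chat-template argument wins (built-in name, .jinja, or JSON file).
        return "explicit", chat_template_arg
    if hf_template:
        # 2. Otherwise use the tokenizer's HuggingFace chat template and detect its content format.
        return "huggingface", hf_template
    if guessed_template:
        # 3. Otherwise guess an SGLang built-in template from the model path.
        return "sglang-builtin", guessed_template
    # 4. Nothing found: default to the "string" content format.
    return "none", "string"
```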
import logging
import math
from collections.abc import Iterable
from math import sqrt
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
import torch
from torch import nn
from torch.nn import LayerNorm
from torch.nn import functional as F
from transformers import PretrainedConfig
from transformers.activations import ACT2FN
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig,
)
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, log_info_on_rank0, make_layers
logger = logging.getLogger(__name__)
"""
Text Model
"""
class Step3TextMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Step3TextMoEMLP(nn.Module):
# Native
def __init__(
self,
layer_id: int,
config: Step3TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.layer_id = layer_id
if self.tp_size > config.moe_num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.moe_num_experts}."
)
self.topk = TopK(
top_k=config.moe_top_k,
renormalize=config.norm_expert_weight,
use_grouped_topk=False,
)
self.experts = get_moe_impl_class()(
num_experts=config.moe_num_experts,
top_k=config.moe_top_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
)
self.gate = ReplicatedLinear(
config.hidden_size,
output_size=config.moe_num_experts,
bias=False,
quant_config=None,
prefix=add_prefix("gate", prefix),
)
if global_server_args_dict["enable_deepep_moe"]:
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits, _ = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states, topk_output=topk_output
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
class Step3TextAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
share_q_dim: int,
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
rms_norm_eps=None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.all_tp_rank = get_tensor_model_parallel_rank()
self.total_num_heads = num_heads
self.attn_tp_rank = attn_tp_rank
self.layer_id = layer_id
assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= attn_tp_size:
# Number of KV heads is at least the TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % attn_tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert attn_tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
self.head_dim = head_dim
self.q_size = share_q_dim if share_q_dim else head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = MergedColumnParallelLinear(
hidden_size,
[self.q_size, self.kv_size, self.kv_size],
bias=False,
quant_config=quant_config,
tp_rank=0,  # Ideally this would be a MergedReplicatedLinear; force replication instead.
tp_size=1,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
reduce_results=False,
prefix=add_prefix("o_proj", prefix),
)
self.inter_norm = RMSNorm(self.q_size, eps=rms_norm_eps)
self.wq = ColumnParallelLinear(
self.q_size,
self.head_dim * self.total_num_heads,
bias=False,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
prefix=add_prefix("wq", prefix),
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = self.inter_norm(q.contiguous())
q, _ = self.wq(q)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
class Step3TextDecoderLayer(nn.Module):
def __init__(
self,
config: Step3TextConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
# TODO: support shared experts fusion
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# 0
# if global_server_args_dict["disable_shared_experts_fusion"]
# else self.n_shared_experts
# )
self.num_fused_shared_experts = 0
rms_norm_eps = config.rms_norm_eps
self.self_attn = Step3TextAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=1,
head_dim=head_dim,
share_q_dim=config.share_q_dim,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=rms_norm_eps,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
moe_layers_enum = getattr(config, "moe_layers_enum", None)
if moe_layers_enum is not None:
# Checkpoint configs provide moe_layers_enum as a comma-separated string of layer indices.
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
else:
# Default: layer 0 is dense, all remaining layers are MoE.
moe_layers_idx = list(range(1, config.num_hidden_layers))
self.use_moe = False
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.layer_id = layer_id
self.is_layer_sparse = layer_id in moe_layers_idx
self.is_previous_layer_sparse = (layer_id - 1) in moe_layers_idx
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=self.is_previous_layer_sparse,
)
if not self.is_layer_sparse:
self.mlp = Step3TextMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act="silu",
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
else:
self.use_moe = True
if self.num_fused_shared_experts == 0:
self.moe = Step3TextMoEMLP(
layer_id=layer_id,
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.share_expert = Step3TextMLP(
hidden_size=config.hidden_size,
intermediate_size=config.share_expert_dim,
hidden_act="silu",
quant_config=quant_config,
prefix=add_prefix("share_expert", prefix),
)
else:
self.moe = Step3TextMoEMLP(
layer_id=layer_id,
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
def moe_mlp_forward(self, hidden_states):
if not self.num_fused_shared_experts:
h = hidden_states.clone()
hidden_states = self.moe(hidden_states)
hidden_states += self.share_expert(h)
else:
hidden_states = self.moe(hidden_states)
return hidden_states
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
if self.use_moe:
hidden_states = self.moe_mlp_forward(hidden_states)
else:
hidden_states = self.mlp(hidden_states)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual
class Step3TextModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
prefix=add_prefix("embed_tokens", prefix),
)
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Step3TextDecoderLayer(
layer_id=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
),
prefix=add_prefix("layers", prefix),
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self):
return self.embed_tokens
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
)
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
"""
Vision Model
"""
def get_abs_pos(abs_pos, tgt_size):
dim = abs_pos.size(-1)
abs_pos_new = abs_pos.squeeze(0)
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
old_pos_embed = (
old_pos_embed.view(1, src_size, src_size, dim)
.permute(0, 3, 1, 2)
.contiguous()
)
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode="bicubic",
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
return vision_pos_embed
else:
return abs_pos
class Step3VisionMLP(nn.Module):
def __init__(
self,
dim: int,
intermediate_size: int,
bias: bool = True,
hidden_act="quick_gelu",
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.fc1 = ColumnParallelLinear(
dim,
intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("gate_proj", prefix),
)
self.act = ACT2FN[hidden_act] # quick_gelu
self.fc2 = RowParallelLinear(
intermediate_size,
dim,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
)
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class Step3VisionAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 16,
qkv_backend="fa3",
quant_config=None,
prefix: str = "",
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.out_proj = RowParallelLinear(
dim,
dim,
bias=True,
quant_config=quant_config,
prefix=add_prefix("out_proj", prefix),
)
self.scale = self.head_dim**-0.5
self.attn = VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
use_qkv_parallel=True,
rotary_embed="normal",
proj_bias=True,
qkv_backend=qkv_backend,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attn_output = self.attn(hidden_states)
return attn_output
class Step3VisionEmbeddings(nn.Module):
def __init__(self, config: Step3VisionEncoderConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=True,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.pad_tp_size = 4  # hard-coded padding size
# To load the pretrained weights, we still use P+1 as the seqlen
self.position_embedding = torch.nn.Embedding(
self.num_patches + 1, self.embed_dim
)
self.register_buffer(
"position_ids",
torch.arange(self.num_patches + 1).expand((1, -1)),
persistent=False,
)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(
pixel_values
) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
# pad
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + get_abs_pos(
self.position_embedding(self.position_ids), patch_embeds.size(1)
)
embeddings = torch.cat(
[
embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1),
embeddings,
],
dim=1,
)
return embeddings
class Step3VisionEncoderLayer(nn.Module):
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.layer_norm1 = LayerNorm(self.embed_dim, eps=1e-6)
self.layer_norm2 = LayerNorm(self.embed_dim, eps=1e-6)
self.self_attn = Step3VisionAttention(
self.embed_dim, num_heads=config.num_attention_heads
)
self.mlp = Step3VisionMLP(
dim=self.embed_dim,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
def forward(self, hidden_states) -> torch.Tensor:
hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states))
hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states))
return hidden_states
class Step3VisionTransformer(nn.Module):
def __init__(self, config: Step3VisionEncoderConfig):
super().__init__()
self.config = config
self.image_size = config.image_size
self.embeddings = Step3VisionEmbeddings(config)
self.transformer = Step3VisionEncoder(config)
@property
def dtype(self) -> torch.dtype:
return self.embeddings.patch_embedding.weight.dtype
def forward(
self,
pixel_values: torch.Tensor,
):
hidden_states = self.embeddings(pixel_values)
hidden_states = self.transformer(inputs_embeds=hidden_states)
return hidden_states
class Step3VisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
[`Step3VisionEncoderLayer`].
Args:
config: Step3VisionEncoderConfig
"""
def __init__(self, config: Step3VisionEncoderConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[Step3VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
def forward(
self,
inputs_embeds,
) -> torch.Tensor:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
)
return hidden_states
class Step3VLForConditionalGeneration(nn.Module):
def __init__(
self,
config: Step3VLConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Step3TextModel(
config.text_config, quant_config, prefix=add_prefix("model", prefix)
)
self.vision_model = Step3VisionTransformer(config.vision_config)
self.vit_downsampler = nn.Conv2d(
config.vision_config.hidden_size,
config.vision_config.output_hidden_size,
kernel_size=2,
stride=config.understand_projector_stride,
)
self.vit_downsampler2 = nn.Conv2d(
config.vision_config.output_hidden_size,
config.vision_config.output_hidden_size * 2,
kernel_size=3,
stride=2,
padding=1,
)
self.vit_large_projector = nn.Linear(
config.vision_config.output_hidden_size * 2,
config.hidden_size,
bias=config.projector_bias,
)
# TODO: support shared experts fusion
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# 0
# if global_server_args_dict["disable_shared_experts_fusion"]
# else self.n_shared_experts
# )
self.num_fused_shared_experts = 0
# Force tied word embeddings off for Step3.
self.config.tie_word_embeddings = False
if getattr(self.config, "tie_word_embeddings", False):
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.text_config.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config.text_config)
def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
return self.vision_model(input_tensor)[:, 4:]
def _flatten_embeddings(self, embeddings) -> torch.Tensor:
if isinstance(embeddings, torch.Tensor):
# Flatten all but the last dimension.
return embeddings.flatten(0, -2)
return torch.cat(tuple(self._flatten_embeddings(t) for t in embeddings))
def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
B, P = image_features.shape[:2]
HW = int(sqrt(P))
image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
image_features = self.vit_downsampler(image_features)
image_features = self.vit_downsampler2(image_features)
n_dim = image_features.size(1)
image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
image_features = self.vit_large_projector(image_features)
return image_features
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
assert len(items) == 1 # We only have images.
item = items[0]
pixel_values = item.feature.type(self.vision_model.dtype)
num_patches = item.model_specific_data.get("num_patches")
patch_pixel_values = item.model_specific_data.get("patch_pixel_values", None)
if patch_pixel_values is not None:
patch_pixel_values = patch_pixel_values.type(self.vision_model.dtype)
patch_pixel_values = patch_pixel_values.to("cuda")
image_features = self._get_vision_model_output(pixel_values)
patch_image_features = (
self._get_vision_model_output(patch_pixel_values)
if patch_pixel_values is not None
else None
)
image_features = self._process_image_features(image_features)
patch_image_features = (
self._process_image_features(patch_image_features)
if patch_image_features is not None
else None
)
merged_image_features = []
cur_patch_idx = 0
for i, num_patch in enumerate(num_patches):
cur_feature = []
if num_patch > 0:
patch_slice = patch_image_features[
cur_patch_idx : cur_patch_idx + num_patch
]
cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
cur_feature.append(image_features[i].view(-1, image_features.shape[-1]))
cur_patch_idx += num_patch
merged_image_features.append(
torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0]
)
return self._flatten_embeddings(merged_image_features)
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# TODO:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", 0),
(".qkv_proj", ".k_proj", 1),
(".qkv_proj", ".v_proj", 2),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.text_config.moe_num_experts
+ self.num_fused_shared_experts,
)
params_dict = dict(self.named_parameters())
loaded_params = set()
def match_expert_and_shard_ids(name_path: str, weight_path: str) -> bool:
name_parts = name_path.split(".")
weight_parts = weight_path.split(".")
shard_id_matches = name_parts[4] == weight_parts[2]
return shard_id_matches
for name, loaded_weight in weights:
if "vision_model" in name:
# 1. Not ideal, but keep this renaming for now.
name = name.replace("self_attn", "self_attn.attn")
# 2. Map the checkpoint's out_proj onto the VisionAttention proj weight name.
name = name.replace("out_proj", "proj")
# TODO: support vision model
if self.num_fused_shared_experts > 0 and "share" in name:
# assert False
name = name.replace("share_expert", "moe")
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if (
expert_id != self.config.text_config.moe_num_experts
or not match_expert_and_shard_ids(name, weight_name)
):
continue
part_name = weight_name.split(".")[-2]
fake_weight_name = name.replace(part_name, weight_name[:-1])
actual_param_name = name.replace(part_name + ".", param_name)
param = params_dict[actual_param_name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "gate." not in name and "moe" in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
if "moe" not in name:
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
else:
if "gate." in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight)
loaded_params.add(name)
continue
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if expert_id == self.config.text_config.moe_num_experts:
continue
if not match_expert_and_shard_ids(name, weight_name):
continue
part_name = weight_name.split(".")[-2]
fake_weight_name = name.replace(part_name, weight_name[:-1])
actual_param_name = name.replace(part_name + ".", param_name)
param = params_dict[actual_param_name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight[expert_id],
name,
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(actual_param_name)
# Don't break here, because this 'loaded_weight' includes all the weights for this layer
@classmethod
def get_model_config_for_expert_location(cls, config: Step3VLConfig):
return ModelConfigForExpertLocation(
num_layers=config.text_config.num_hidden_layers,
num_logical_experts=config.text_config.moe_num_experts,
num_groups=None,
)
EntryClass = Step3VLForConditionalGeneration
......@@ -176,6 +176,8 @@ class BaseMultimodalProcessor(ABC):
"image_grid_hws": Modality.IMAGE,
"aspect_ratio_ids": Modality.IMAGE,
"aspect_ratio_mask": Modality.IMAGE,
"num_patches": Modality.IMAGE,
"patch_pixel_values": Modality.IMAGE,
# Audio-related attributes
"audio_features": Modality.AUDIO,
"audio_feature_lens": Modality.AUDIO,
......
import math
import re
from itertools import product
from typing import List, Literal, Optional, TypedDict, Union
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, TensorType
from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
class GPUToTensor(torch.nn.Module):
def forward(self, raw_image: Union[np.ndarray, Image.Image]) -> torch.Tensor:
if isinstance(raw_image, Image.Image):
return transforms.ToTensor()(raw_image)
if raw_image.ndim == 2:
raw_image = raw_image[:, :, None].repeat(3, -1)
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
image_tensor = torch.from_numpy(raw_image).to(device)
image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
if image_tensor.dtype == torch.uint8:
image_tensor = image_tensor.to(torch.float32).div(255)
return image_tensor
class Step3VisionProcessor:
def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
patch_size = patch_size if patch_size is not None else size
self.transform = transforms.Compose(
[
GPUToTensor(),
transforms.Normalize(mean, std),
transforms.Resize(
(size, size),
interpolation=(
InterpolationMode.BICUBIC
if interpolation_mode == "bicubic"
else InterpolationMode.BILINEAR
),
antialias=True,
),
]
)
self.patch_transform = (
transforms.Compose(
[
GPUToTensor(),
transforms.Normalize(mean, std),
transforms.Resize(
(patch_size, patch_size),
interpolation=(
InterpolationMode.BICUBIC
if interpolation_mode == "bicubic"
else InterpolationMode.BILINEAR
),
antialias=True,
),
]
)
if patch_size is not None
else None
)
def __call__(self, image, is_patch=False):
if is_patch:
return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
else:
return {"pixel_values": self.transform(image).unsqueeze(0)}
class ImagePatcher:
def determine_window_size(self, long: int, short: int) -> int:
if long <= 728:
return short if long / short > 1.5 else 0
return min(short, 504) if long / short > 4 else 504
def slide_window(
self,
width: int,
height: int,
sizes: list[tuple[int, int]],
steps: list[tuple[int, int]],
img_rate_thr: float = 0.6,
) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
assert 1 >= img_rate_thr >= 0, "img_rate_thr should lie in [0, 1]"
windows = []
# Sliding windows.
for size, step in zip(sizes, steps):
size_w, size_h = size
step_w, step_h = step
x_num = 1 if width <= size_w else math.ceil((width - size_w) / step_w + 1)
x_start = [step_w * i for i in range(x_num)]
if len(x_start) > 1 and x_start[-1] + size_w > width:
x_start[-1] = width - size_w
y_num = 1 if height <= size_h else math.ceil((height - size_h) / step_h + 1)
y_start = [step_h * i for i in range(y_num)]
if len(y_start) > 1 and y_start[-1] + size_h > height:
y_start[-1] = height - size_h
start = np.array(list(product(y_start, x_start)), dtype=int)
start[:, [0, 1]] = start[:, [1, 0]]
windows.append(np.concatenate([start, start + size], axis=1))
windows = np.concatenate(windows, axis=0)
return [
(int(box[0]), int(box[1]), int(box[2] - box[0]), int(box[3] - box[1]))
for box in windows
], (x_num, y_num)
def square_pad(self, img: Image.Image) -> Image.Image:
w, h = img.size
if w == h:
return img
size = max(w, h)
padded = Image.new(img.mode, (size, size), 0)
padded.paste(img, (0, 0))
return padded
def get_image_size_for_padding(
self, img_width: int, img_height: int
) -> tuple[int, int]:
ratio = img_width / img_height
if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
new_size = max(img_height, img_width)
return new_size, new_size
return img_width, img_height
def get_image_size_for_preprocess(
self, img_width: int, img_height: int
) -> tuple[int, int]:
if max(img_height, img_width) > 3024:
scale_factor = 3024 / max(img_height, img_width)
img_width = int(img_width * scale_factor)
img_height = int(img_height * scale_factor)
return img_width, img_height
else:
return img_width, img_height
def get_image_size_for_crop(
self, img_width: int, img_height: int, window_size: int
):
w_ratio = img_width / window_size
h_ratio = img_height / window_size
if w_ratio < 1:
width_new = img_width
else:
decimal_w = w_ratio - img_width // window_size
w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
width_new = window_size * w_ratio
if h_ratio < 1:
height_new = img_height
else:
decimal_h = h_ratio - img_height // window_size
h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
height_new = window_size * h_ratio
return int(width_new), int(height_new)
def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
target = img.crop((j, i, j + tw, i + th))
return target
def get_num_patches(self, img_width: int, img_height: int) -> tuple[int, int]:
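        # Compute (number of window crops, number of <patch_newline> separators)
        # for an image of the given size, without actually cropping it.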
img_width, img_height = self.get_image_size_for_padding(img_width, img_height)
img_width, img_height = self.get_image_size_for_preprocess(
img_width, img_height
)
window_size = self.determine_window_size(
max(img_height, img_width), min(img_height, img_width)
)
if window_size == 0:
return 0, 0
else:
img_width, img_height = self.get_image_size_for_crop(
img_width, img_height, window_size
)
center_list, (x_num, y_num) = self.slide_window(
img_width,
img_height,
[(window_size, window_size)],
[(window_size, window_size)],
)
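            # One <patch_newline> per crop row, minus the trailing one when the grid
            # divides evenly (matches the newline handling in __call__).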
full_rows = (len(center_list) - 1) // x_num + 1
if len(center_list) > 0 and len(center_list) % x_num == 0:
full_rows -= 1
return len(center_list), full_rows
def __call__(
self, img: Image.Image
) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
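        # Returns the (possibly resized) full image, its window crops, and a per-crop
        # mask marking which crops are followed by a <patch_newline>; the mask is None
        # when no cropping is performed.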
img_width, img_height = img.size
new_img_width, new_img_height = self.get_image_size_for_padding(
img_width, img_height
)
if new_img_width != img_width or new_img_height != img_height:
img = self.square_pad(img)
img_width, img_height = img.size
new_img_width, new_img_height = self.get_image_size_for_preprocess(
img_width, img_height
)
img = img.resize((new_img_width, new_img_height), Image.Resampling.BILINEAR)
window_size = self.determine_window_size(
max(new_img_height, new_img_width), min(new_img_height, new_img_width)
)
if window_size == 0:
return img, [], None
else:
new_img_width, new_img_height = self.get_image_size_for_crop(
new_img_width, new_img_height, window_size
)
if (new_img_width, new_img_height) != (img_width, img_height):
img_for_crop = img.resize(
(new_img_width, new_img_height), Image.Resampling.BILINEAR
)
else:
img_for_crop = img
patches = []
newlines = []
center_list, (x_num, y_num) = self.slide_window(
new_img_width,
new_img_height,
[(window_size, window_size)],
[(window_size, window_size)],
)
for patch_id, center_lf_point in enumerate(center_list):
x, y, patch_w, patch_h = center_lf_point
big_patch = self.patch_crop(img_for_crop, y, x, patch_h, patch_w)
patches.append(big_patch)
if (patch_id + 1) % x_num == 0:
newlines.append(patch_id)
if newlines and newlines[-1] == len(patches) - 1:
newlines.pop()
return (
img,
patches,
(
[i in newlines for i in range(len(patches))]
if len(patches) > 0
else None
),
)
class Step3VLProcessor:
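    # HF-style processor: expands each <im_patch> placeholder in the prompt into the
    # full Step3-VL image/patch token layout and stacks the corresponding pixel tensors.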
def __init__(
self,
config,
tokenizer,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
self.image_size = 728
self.patch_size = 504
self.image_preprocessor = Step3VisionProcessor(
self.image_size, "bilinear", self.patch_size
)
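        # Feature tokens emitted per 728-px global view (169) and per 504-px crop (81).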
self.num_image_feature_size = 169
self.num_patch_feature_size = 81
self.image_token = "<im_patch>"
self.image_feature_placeholder = self.image_token * self.num_image_feature_size
self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size
self.patcher = ImagePatcher()
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[self.image_token]
def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
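        # Each crop contributes its patch features plus <patch_start>/<patch_end> (2 tokens),
        # the global view adds its image features plus <im_start>/<im_end> (2 tokens),
        # and one token is added per <patch_newline> row separator.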
num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height)
return (
num_patches * (self.num_patch_feature_size + 2)
+ self.num_image_feature_size
+ 2
+ num_newlines
)
def _split_images(self, images: list[Image.Image]) -> list[ImageWithPatches]:
result = []
for img in images:
result.append(self.patcher(img))
return result
def _convert_images_to_pixel_values(
self,
images: list[Image.Image],
is_patch: bool = False,
) -> list[torch.Tensor]:
return [
self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
for img in images
]
def _get_patch_repl(
self,
num_patches: int,
patch_newline_mask: list[bool] | None,
) -> tuple[str, list[int]]:
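        # Build the text and token-id expansion for `num_patches` crops, appending
        # <patch_newline> after every crop flagged in `patch_newline_mask`.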
text = ""
token_ids = []
for i in range(num_patches):
assert len(patch_newline_mask) == num_patches
text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
token_ids.extend(
[self.tokenizer.convert_tokens_to_ids("<patch_start>")]
+ [self.image_token_id] * self.num_patch_feature_size
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")]
)
if patch_newline_mask and patch_newline_mask[i]:
text += "<patch_newline>"
token_ids.append(
self.tokenizer.convert_tokens_to_ids("<patch_newline>")
)
return text, token_ids
def _get_image_repl(
self,
num_images: int,
) -> tuple[str, list[int]]:
text = f"<im_start>{self.image_feature_placeholder}<im_end>"
token_ids = (
[self.tokenizer.convert_tokens_to_ids("<im_start>")]
+ [self.image_token_id] * self.num_image_feature_size
+ [self.tokenizer.convert_tokens_to_ids("<im_end>")]
)
return text * num_images, token_ids * num_images
def _get_image_repl_features(
self,
num_images: int,
num_patches: int,
patch_new_line_idx: Optional[list[bool]],
) -> tuple[str, list[int]]:
if num_patches > 0:
patch_repl, patch_repl_ids = self._get_patch_repl(
num_patches, patch_new_line_idx
)
else:
patch_repl = ""
patch_repl_ids = []
image_repl, image_repl_ids = self._get_image_repl(num_images)
return patch_repl + image_repl, patch_repl_ids + image_repl_ids
def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str:
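        # Replace the i-th occurrence of `placeholder` in `text` with `repls[i]`;
        # the occurrence count must match len(repls).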
parts = text.split(placeholder)
if len(parts) - 1 != len(repls):
raise ValueError(
"The number of placeholders does not match the number of replacements." # noqa: E501
)
result = [parts[0]]
for i, repl in enumerate(repls):
result.append(repl)
result.append(parts[i + 1])
return "".join(result)
def __call__(
self,
text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
*args,
**kwargs,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if len(images) == 0:
image_inputs = {}
text_inputs = self.tokenizer(text)
else:
splitted_images_data = self._split_images(images)
pixel_values_lst = []
patch_pixel_values_lst = []
patch_newline_mask_lst = []
image_repl_str_lst = []
image_repl_ids_lst = []
num_patches = []
for (
raw_img,
img_patches,
patch_newline_mask,
) in splitted_images_data: # noqa: E501
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
if len(img_patches) > 0:
patch_pixel_values_lst.extend(
self._convert_images_to_pixel_values(img_patches, is_patch=True)
)
num_patches.append(len(img_patches))
image_repl_str, image_repl_ids = self._get_image_repl_features(
1, len(img_patches), patch_newline_mask
)
image_repl_str_lst.append(image_repl_str)
image_repl_ids_lst.extend(image_repl_ids)
if patch_newline_mask is not None:
patch_newline_mask_lst.extend(patch_newline_mask)
image_inputs = {
"pixel_values": torch.cat(pixel_values_lst),
"num_patches": num_patches,
}
if patch_pixel_values_lst:
image_inputs["patch_pixel_values"] = torch.cat(patch_pixel_values_lst)
if patch_newline_mask_lst:
image_inputs["patch_newline_mask"] = torch.tensor(
patch_newline_mask_lst, dtype=torch.bool
)
text = [
self.replace_placeholder(t, self.image_token, image_repl_str_lst)
for t in text
]
text_inputs = self.tokenizer(text)
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
################################################
class Step3VLImageProcessor(SGLangBaseProcessor):
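    # SGLang multimodal processor: wraps Step3VLProcessor and implements the async
    # hook used by the runtime to turn raw image data into input ids and mm items.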
models = [Step3VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        # TODO: check whether `_processor` is a tokenizer or a full processor.
processor = Step3VLProcessor(hf_config, _processor)
super().__init__(hf_config, server_args, processor, *args, **kwargs)
self.IM_TOKEN_ID = 128001
self.mm_tokens = MultimodalSpecialTokens(
image_token="<im_patch>",
image_token_id=128001,
image_token_regex=re.compile(r"(?:<im_patch>)"),
).build(_processor)
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        # Assumed fix: preprocess()/__call__() below reference self.transform, which is
        # otherwise never set; reuse the full-image pipeline built for Step3VLProcessor.
        self.transform = processor.image_preprocessor.transform
def preprocess(self, image):
return {"pixel_values": self.transform(image).unsqueeze(0)}
def __call__(self, image):
return self.preprocess(image)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text: str | List[int],
request_obj,
*args,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
video_data=request_obj.video_data,
multimodal_tokens=self.mm_tokens,
)
mm_items, input_ids, ret = self.process_and_combine_mm_data(
base_output, self.mm_tokens
)
return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
"im_token_id": self.mm_tokens.image_token_id,
}
......@@ -105,7 +105,7 @@ class BaseReasoningFormatDetector:
# If we're not in a reasoning block return as normal text
if not self._in_reasoning:
self._buffer = ""
return StreamingParseResult(normal_text=new_text)
return StreamingParseResult(normal_text=current_text)
return StreamingParseResult()
......@@ -233,6 +233,7 @@ class ReasoningParser:
"qwen3-thinking": Qwen3ThinkingDetector,
"glm45": Qwen3Detector,
"kimi": KimiDetector,
"step3": DeepSeekR1Detector,
}
def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
......
......@@ -1117,9 +1117,10 @@ class ServerArgs:
"kimi_k2",
"qwen3_coder",
"glm45",
"step3",
],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
)
# Data parallelism
......
......@@ -493,5 +493,117 @@ class TestIntegrationScenarios(CustomTestCase):
self.assertIn("final answer", all_normal)
class TestBufferLossBugFix(CustomTestCase):
"""Test cases for the buffer loss bug fix in parse_streaming_increment."""
def test_partial_end_tag_buffer_loss_bug(self):
"""
Test the bug where partial end tag fragments are lost when followed by normal text.
Bug scenario:
1. _in_reasoning is False
2. new_text is "</" (part of closing thinking tag)
3. Fragment is stored in buffer and empty string is returned
4. Next step: new_text is "answer", _in_reasoning still False
5. Buffer is cleared and "answer" is returned directly
6. The "</" from previous step is lost
This test verifies the fix where line 108 was changed from:
return StreamingParseResult(normal_text=new_text)
to:
return StreamingParseResult(normal_text=current_text)
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Step 1: Send partial end tag when not in reasoning mode
# This should be buffered since it could be start of "</think>"
result1 = detector.parse_streaming_increment("</")
self.assertEqual(result1.normal_text, "")
self.assertEqual(result1.reasoning_text, "")
# Step 2: Send normal text that doesn't complete the end tag
# Before fix: would return only "answer", losing the "</"
# After fix: should return the complete buffered content "</answer"
result2 = detector.parse_streaming_increment("answer")
self.assertEqual(result2.normal_text, "</answer")
self.assertEqual(result2.reasoning_text, "")
def test_partial_start_tag_buffer_preservation(self):
"""
Test that partial start tag fragments are properly preserved.
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Send partial start tag
result1 = detector.parse_streaming_increment("<th")
self.assertEqual(result1.normal_text, "")
self.assertEqual(result1.reasoning_text, "")
# Complete with non-matching text
result2 = detector.parse_streaming_increment("is is text")
self.assertEqual(result2.normal_text, "<this is text")
self.assertEqual(result2.reasoning_text, "")
def test_partial_end_tag_in_reasoning_mode(self):
"""
Test partial end tag handling when already in reasoning mode.
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Enter reasoning mode
detector.parse_streaming_increment("<think>")
detector.parse_streaming_increment("some reasoning")
# Send partial end tag
result1 = detector.parse_streaming_increment("</")
self.assertEqual(result1.normal_text, "")
self.assertEqual(result1.reasoning_text, "")
# Complete the end tag with normal text
result2 = detector.parse_streaming_increment("think>normal text")
self.assertEqual(result2.normal_text, "normal text")
        # Reasoning text is empty here: the earlier reasoning was already streamed, and
        # nothing precedes the end tag in the buffered text.
self.assertEqual(result2.reasoning_text, "")
def test_multiple_partial_fragments(self):
"""
Test handling of multiple partial fragments that don't match any tokens.
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Send multiple partial fragments
result1 = detector.parse_streaming_increment("<")
self.assertEqual(result1.normal_text, "")
self.assertEqual(result1.reasoning_text, "")
result2 = detector.parse_streaming_increment("/")
self.assertEqual(result2.normal_text, "")
self.assertEqual(result2.reasoning_text, "")
result3 = detector.parse_streaming_increment("random>")
self.assertEqual(result3.normal_text, "</random>")
self.assertEqual(result3.reasoning_text, "")
def test_edge_case_exact_token_match(self):
"""
Test edge case where buffer content exactly matches a token.
"""
detector = BaseReasoningFormatDetector("<think>", "</think>")
# Build up the exact start token character by character
detector.parse_streaming_increment("<")
detector.parse_streaming_increment("t")
detector.parse_streaming_increment("h")
detector.parse_streaming_increment("i")
detector.parse_streaming_increment("n")
result = detector.parse_streaming_increment("k>")
# Should enter reasoning mode
self.assertEqual(result.normal_text, "")
self.assertEqual(result.reasoning_text, "")
self.assertTrue(detector._in_reasoning)
self.assertTrue(detector.stripped_think_start)
if __name__ == "__main__":
unittest.main()