Unverified commit 94846416, authored by Song and committed by GitHub
Browse files

[Model] Add step3 vl (#21998)


Signed-off-by: oliveryuan <yuansong@step.ai>
Co-authored-by: oliveryuan <yuansong@step.ai>
parent 207b750e
...@@ -625,6 +625,7 @@ See [this page](generative_models.md) for more information on how to use generat ...@@ -625,6 +625,7 @@ See [this page](generative_models.md) for more information on how to use generat
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | | `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | | `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
......
...@@ -279,6 +279,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -279,6 +279,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3",
trust_remote_code=True,
is_available_online=False),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct", "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
trust_remote_code=True), trust_remote_code=True),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
...@@ -457,6 +460,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -457,6 +460,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
trust_remote_code=True), trust_remote_code=True),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3",
trust_remote_code=True,
is_available_online=False),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501 "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501
......
...@@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser ...@@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser
from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser from .xlam_tool_parser import xLAMToolParser
__all__ = [ __all__ = [
...@@ -40,4 +41,5 @@ __all__ = [ ...@@ -40,4 +41,5 @@ __all__ = [
"HunyuanA13BToolParser", "HunyuanA13BToolParser",
"Glm4MoeModelToolParser", "Glm4MoeModelToolParser",
"Qwen3CoderToolParser", "Qwen3CoderToolParser",
"Step3ToolParser",
] ]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import json
from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module(["step3"])
class Step3ToolParser(ToolParser):
    """
    Tool parser for the Step3 XML-like ("steptml") tool-call format.

    The parser recognises one tool block delimited by
    ``<|tool_calls_begin|>`` and ``<|tool_calls_end|>``. Inside it, every
    call sits between ``<|tool_call_begin|>`` and ``<|tool_call_end|>`` and
    carries a ``<steptml:invoke name="...">`` element with zero or more
    ``<steptml:parameter name="...">value</steptml:parameter>`` entries.

    Streaming is driven by a cursor (``self.position``) into the accumulated
    text plus two explicit flags. For each call the function name is emitted
    first; the arguments are emitted once, as a single JSON string, when the
    call's end marker has arrived.
    """

    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [
        TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
    ]

    # Compiled once at class creation; the streaming path re-parses the
    # pending tool body on every delta.
    _INVOKE_RE = re.compile(r'<steptml:invoke name="([^"]+)">')
    # Non-greedy + DOTALL so that a value may contain '<' (e.g. code such as
    # "a < b") or newlines; the match still stops at the first closing tag.
    _PARAMETER_RE = re.compile(
        r'<steptml:parameter name="([^"]+)">(.*?)</steptml:parameter>',
        re.DOTALL)

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # Index into ``current_text`` up to which the stream was consumed.
        self.position = 0
        # Explicit state flags for robust streaming
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """
        Keep special tokens in the decoded text when tools may be called.

        The markers this parser searches for are only visible in the text if
        they are not skipped during detokenization, so
        ``skip_special_tokens`` is switched off whenever the request carries
        tools and ``tool_choice`` is not ``'none'``.
        """
        if request.tools and request.tool_choice != 'none':
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str
    ) -> tuple[Optional[str], Optional[dict[str, str]]]:
        """
        Extract the function name and raw string parameters from a call body.

        Returns:
            ``(None, None)`` when no complete ``<steptml:invoke name="...">``
            tag is present; otherwise ``(name, params)`` where ``params`` maps
            each complete parameter's name to its whitespace-stripped text.
            A parameter whose closing tag has not arrived yet is not included.
        """
        func_name_match = Step3ToolParser._INVOKE_RE.search(action_text)
        if not func_name_match:
            return None, None
        func_name = func_name_match.group(1)

        params: dict[str, str] = {}
        for name, value in Step3ToolParser._PARAMETER_RE.findall(action_text):
            params[name] = value.strip()
        return func_name, params

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """
        Convert string parameter values according to the tool's JSON schema.

        The first tool in ``request.tools`` whose name equals ``func_name``
        supplies the schema. Only values that are still strings are touched,
        and a value that cannot be converted is left unchanged. ``params`` is
        modified in place and also returned.
        """
        for tool in request.tools or []:
            if tool.function.name != func_name:
                continue
            schema = tool.function.parameters or {}
            # ``or {}`` also covers an explicit ``"properties": null``.
            properties = schema.get("properties") or {}
            # Snapshot the items: values are reassigned while looping.
            for key, value in list(params.items()):
                if not isinstance(value, str):
                    continue
                prop = properties.get(key)
                # A property schema that is not a mapping carries no usable
                # "type"; leave the raw string as is.
                typ = prop.get("type") if isinstance(prop, dict) else None
                if typ == "string":
                    params[key] = value.strip()
                elif typ == "integer":
                    with contextlib.suppress(ValueError):
                        params[key] = int(value)
                elif typ == "number":
                    with contextlib.suppress(ValueError):
                        params[key] = float(value)
                elif typ == "boolean":
                    lower_val = value.lower()
                    if lower_val in ("true", "false"):
                        params[key] = lower_val == "true"
                elif typ == "null":
                    if value.lower() == "null":
                        params[key] = None
            break
        return params

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Produce the next delta from the not-yet-consumed part of the stream.

        Only ``current_text`` and ``request`` are read; progress is tracked in
        ``self.position`` rather than derived from the delta arguments.

        Returns:
            A ``DeltaMessage`` carrying plain content, a tool-call name, or a
            tool call's complete JSON arguments; ``None`` when more text is
            needed before anything can be emitted.

        NOTE(review): ``current_tool_id``, ``current_tool_name_sent`` and
        ``prev_tool_call_arr`` are not initialised in this class; presumably
        they come from ``ToolParser`` - confirm against the base class.
        """
        # The main loop processes the stream from the last known position.
        while True:
            if self.position >= len(current_text):
                return None  # We've processed the entire stream.

            unprocessed_text = current_text[self.position:]

            # STATE: After all tools are done, all subsequent text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: Before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue  # Token consumed, re-loop.

                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos == -1:
                    if self.TOOL_CALLS_BEGIN.startswith(
                            unprocessed_text.strip()) and unprocessed_text:
                        return None  # It's a prefix, wait.
                    self.position = len(current_text)
                    return DeltaMessage(content=unprocessed_text)
                else:
                    # Emit the text in front of the marker; the marker itself
                    # is consumed on the next call.
                    content = unprocessed_text[:start_pos]
                    self.position += len(content)
                    return DeltaMessage(content=content)

            # STATE: Inside the main tool block.
            # Leading whitespace between markers is skipped, not emitted.
            offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
            unprocessed_text = unprocessed_text.lstrip()
            self.position += offset

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # Check if we are between tool calls.
            tool_finished = (
                self.current_tool_id != -1 and
                self.prev_tool_call_arr[self.current_tool_id].get("finished"))
            if self.current_tool_id == -1 or tool_finished:
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self.position += len(self.TOOL_CALL_BEGIN)
                    if self.current_tool_id == -1:
                        self.current_tool_id = 0
                    else:
                        self.current_tool_id += 1
                    self.current_tool_name_sent = False
                    while len(self.prev_tool_call_arr) <= self.current_tool_id:
                        self.prev_tool_call_arr.append({})
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = False
                    continue

                if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
                    return None

            # STATE: Parsing an active tool call.
            if self.current_tool_id != -1 and not self.prev_tool_call_arr[
                    self.current_tool_id].get("finished", False):
                end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
                if end_tool_pos == -1:
                    tool_body = unprocessed_text
                else:
                    tool_body = unprocessed_text[:end_tool_pos]

                if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
                        tool_body):
                    return None

                function_name, arguments = self._parse_steptml_invoke(
                    tool_body)
                if not function_name:
                    return None

                tool_call_arr = {
                    "name": function_name,
                    "parameters": arguments or {}
                }

                # Send the function name as soon as it's parsed.
                if not self.current_tool_name_sent:
                    self.current_tool_name_sent = True
                    self.prev_tool_call_arr[self.current_tool_id].update(
                        tool_call_arr)
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name))
                    ])

                # Update our internal state with the latest parsed arguments.
                self.prev_tool_call_arr[self.current_tool_id].update(
                    tool_call_arr)

                # Only send arguments when the tool call is complete.
                if end_tool_pos != -1:
                    self.position += end_tool_pos + len(self.TOOL_CALL_END)
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = True

                    final_args = self._cast_arguments(
                        function_name,
                        tool_call_arr.get("parameters", {}),  # type: ignore
                        request)
                    # Always emit the arguments, "{}" included: a call without
                    # parameters must still deliver a valid JSON string, as
                    # extract_tool_calls() does for the same input.
                    final_args_json = json.dumps(final_args,
                                                 ensure_ascii=False)
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=final_args_json))
                    ])

                # If tool is not finished, return None to wait for more tokens.
                return None

            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Parse a complete model output into tool calls and remaining content.

        A call is accepted only if it is closed by ``<|tool_call_end|>``, has
        the form ``function<|tool_sep|>...`` and contains a parsable invoke
        tag. If the tool block is missing, unterminated, or yields no accepted
        call, the whole ``model_output`` is returned as content with
        ``tools_called=False``. Otherwise the text before and after the block
        is joined and stripped to form the content (``None`` when empty).
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        for part in tool_block.split(self.TOOL_CALL_BEGIN):
            if not part or self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(
                invoke_part)
            if function_name and params_dict is not None:
                params_dict = self._cast_arguments(function_name, params_dict,
                                                   request)
                params_str = json.dumps(params_dict, ensure_ascii=False)
                tool_calls.append(
                    ToolCall(function=FunctionCall(name=function_name,
                                                   arguments=params_str)))

        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None)

        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
...@@ -129,6 +129,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -129,6 +129,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
...@@ -238,6 +239,7 @@ _MULTIMODAL_MODELS = { ...@@ -238,6 +239,7 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"), "UltravoxModel": ("ultravox", "UltravoxModel"),
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
......
This diff is collapsed.
This diff is collapsed.
...@@ -8,6 +8,7 @@ from .granite_reasoning_parser import GraniteReasoningParser ...@@ -8,6 +8,7 @@ from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser from .mistral_reasoning_parser import MistralReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser
from .step3_reasoning_parser import Step3ReasoningParser
__all__ = [ __all__ = [
"ReasoningParser", "ReasoningParser",
...@@ -18,4 +19,5 @@ __all__ = [ ...@@ -18,4 +19,5 @@ __all__ = [
"Qwen3ReasoningParser", "Qwen3ReasoningParser",
"Glm4MoeModelReasoningParser", "Glm4MoeModelReasoningParser",
"MistralReasoningParser", "MistralReasoningParser",
"Step3ReasoningParser",
] ]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional, Union
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
@ReasoningParserManager.register_module("step3")
class Step3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Step3 model.

    The Step3 model uses </think> token to denote the end of reasoning
    text. This parser extracts all content before </think> as reasoning
    content and everything after it as regular content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """
        Look up the id of the ``</think>`` token in the tokenizer vocabulary.

        Raises:
            ValueError: if no tokenizer is available on the parser.
            RuntimeError: if ``</think>`` is not in the vocabulary.
        """
        super().__init__(tokenizer)
        self.think_end_token = "</think>"

        # Kept as an attribute; no method of this class reads it.
        self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
                                          re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if self.think_end_token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end "
                "token in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs to decide on which side of </think> the delta lies.
        For text "abc</think>xyz":
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns:
            ``None`` when the delta is exactly the ``</think>`` token (or
            carries no usable text); otherwise a ``DeltaMessage``.
        """
        # Skip single special token
        if len(delta_token_ids
               ) == 1 and delta_token_ids[0] == self.think_end_token_id:
            return None

        if self.think_end_token_id in delta_token_ids:
            # </think> in delta, extract reasoning content and remaining content
            end_index = delta_text.find(self.think_end_token)
            if end_index == -1:
                # The token id is present but its text is not, so the split
                # point is unknown. Slicing with -1 would drop or duplicate
                # characters; emit the delta unchanged instead.
                # NOTE(review): classifying it as reasoning mirrors the
                # "no </think> in text" rule of extract_reasoning_content();
                # confirm this is the preferred side.
                if not delta_text:
                    return None
                return DeltaMessage(reasoning_content=delta_text)
            reasoning_content = delta_text[:end_index]
            content = delta_text[end_index + len(self.think_end_token):]
            return DeltaMessage(reasoning_content=reasoning_content,
                                content=content if content else None)
        elif self.think_end_token_id in previous_token_ids:
            # </think> already seen in previous text, everything is content
            return DeltaMessage(content=delta_text)
        else:
            # No </think> seen yet, everything is reasoning
            return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Split a complete output at the first ``</think>``.

        Returns:
            ``(reasoning, content)``. Without a ``</think>`` marker the whole
            output is reasoning and content is ``None``; content is also
            ``None`` when nothing follows the marker.
        """
        reasoning_content, marker, content = model_output.partition(
            self.think_end_token)
        if not marker:
            # If no </think> token, everything is reasoning content
            return model_output, None
        return reasoning_content, (content if content else None)

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if the ``</think>`` token id occurs in ``input_ids``."""
        return self.think_end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Return the token ids that follow the first ``</think>`` token.

        An empty list is returned when the token is absent or is the last
        element, i.e. when no content ids exist yet.
        """
        if self.think_end_token_id not in input_ids[:-1]:
            return []
        return input_ids[input_ids.index(self.think_end_token_id) + 1:]
...@@ -35,7 +35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config, ...@@ -35,7 +35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
MllamaConfig, MLPSpeculatorConfig, MllamaConfig, MLPSpeculatorConfig,
Nemotron_Nano_VL_Config, Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config, NemotronConfig, NVLM_D_Config,
RWConfig, UltravoxConfig) RWConfig, Step3TextConfig,
Step3VLConfig, UltravoxConfig)
# yapf: enable # yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
...@@ -83,6 +84,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { ...@@ -83,6 +84,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"nemotron": NemotronConfig, "nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config, "NVLM_D": NVLM_D_Config,
"ultravox": UltravoxConfig, "ultravox": UltravoxConfig,
"step3_vl": Step3VLConfig,
"step3_text": Step3TextConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF **_CONFIG_REGISTRY_OVERRIDE_HF
} }
......
...@@ -24,6 +24,9 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig ...@@ -24,6 +24,9 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [ __all__ = [
...@@ -42,4 +45,7 @@ __all__ = [ ...@@ -42,4 +45,7 @@ __all__ = [
"Nemotron_Nano_VL_Config", "Nemotron_Nano_VL_Config",
"NVLM_D_Config", "NVLM_D_Config",
"UltravoxConfig", "UltravoxConfig",
"Step3VLConfig",
"Step3VisionEncoderConfig",
"Step3TextConfig",
] ]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
from transformers.configuration_utils import PretrainedConfig
class Step3VisionEncoderConfig(PretrainedConfig):
    """
    Configuration object registered under model type
    ``"step3_vision_encoder"``.

    Every named constructor argument is stored unchanged as an attribute of
    the same name; all remaining keyword arguments are handed to
    ``PretrainedConfig.__init__``.
    """

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size=1792,
        intermediate_size=3072,
        output_hidden_size=4096,
        num_hidden_layers=63,
        num_attention_heads=16,
        num_channels=3,
        image_size=728,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        # Input-related settings.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size

        # Layer sizes and counts.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.output_hidden_size = output_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Activation and normalisation settings.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps

        # The base class consumes whatever was not named above.
        super().__init__(**kwargs)
class Step3TextConfig(PretrainedConfig):
    """
    Configuration object registered under model type ``"step3_text"``.

    Every named constructor argument is stored unchanged as an attribute of
    the same name; all remaining keyword arguments are handed to
    ``PretrainedConfig.__init__``.
    """

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        # Layer indices 4 through 59 inclusive.
        moe_layers_enum: tuple[int, ...] = tuple(range(4, 60)),
        **kwargs,
    ) -> None:
        # Overall sizes.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.rms_norm_eps = rms_norm_eps

        # Attention settings.
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.head_dim = head_dim
        self.share_q_dim = share_q_dim

        # Sequence-length and rotary-embedding settings.
        self.max_seq_len = max_seq_len
        self.max_position_embedding = max_position_embedding
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Mixture-of-experts settings.
        self.moe_layers_enum = moe_layers_enum
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_intermediate_size = moe_intermediate_size
        self.share_expert_dim = share_expert_dim
        self.norm_expert_weight = norm_expert_weight

        # The base class consumes whatever was not named above.
        super().__init__(**kwargs)
class Step3VLConfig(PretrainedConfig):
    """
    Configuration object registered under model type ``"step3_vl"``.

    It bundles a ``Step3VisionEncoderConfig`` and a ``Step3TextConfig``. Each
    may be given as a ready config object, as a dict of constructor keyword
    arguments, or omitted (``None``) to use that class's defaults.
    ``hidden_size`` is copied from the text config.
    """

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        # Normalise the vision sub-config: dict -> object, None -> defaults.
        if isinstance(vision_config, dict):
            vision_config = Step3VisionEncoderConfig(**vision_config)
        elif vision_config is None:
            vision_config = Step3VisionEncoderConfig()
        self.vision_config = vision_config

        # Same normalisation for the text sub-config.
        if isinstance(text_config, dict):
            text_config = Step3TextConfig(**text_config)
        elif text_config is None:
            text_config = Step3TextConfig()
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        self.image_token_id = image_token_id
        # Expose the text model's hidden size on the top-level config.
        self.hidden_size = text_config.hidden_size

        # The base class consumes whatever was not named above.
        super().__init__(**kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment