Commit a8c6fcf6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev

parents 0480314d 66540380
......@@ -598,6 +598,7 @@ Specified using `--task generate`.
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
......
......@@ -266,6 +266,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"StableLMEpochForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"stabilityai/stablelm-zephyr-3b")), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"stabilityai/stablelm-3b-4e1t")),
"Starcoder2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"bigcode/starcoder2-3b")),
"Step3TextForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "stepfun-ai/step3"),
trust_remote_code=True,
is_available_online=False),
"SolarForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"upstage/solar-pro-preview-instruct")),
"TeleChat2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"Tele-AI/TeleChat2-3B"),
trust_remote_code=True),
......@@ -423,6 +426,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-7B-AWQ")), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Skywork/Skywork-R1V-38B")),
"SmolVLMForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM2-2.2B-Instruct")), # noqa: E501
"Step3VLForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "stepfun-ai/step3"),
trust_remote_code=True,
is_available_online=False),
"UltravoxModel": _HfExamplesInfo(os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b"), # noqa: E501
trust_remote_code=True),
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501
......
......@@ -15,6 +15,7 @@ from .minimax_tool_parser import MinimaxToolParser
from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
__all__ = [
......@@ -31,6 +32,7 @@ __all__ = [
"PythonicToolParser",
"Phi4MiniJsonToolParser",
"DeepSeekV3ToolParser",
"Step3ToolParser",
"xLAMToolParser",
"MinimaxToolParser",
"Glm4MoeModelToolParser",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import json
from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module(["step3"])
class Step3ToolParser(ToolParser):
"""
Tool parser for a model that uses a specific XML-like format for tool calls.
This version uses a robust, stateful, cursor-based streaming parser and
consolidates tool arguments into a single message.
"""
TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
TOOL_CALLS_END = "<|tool_calls_end|>"
TOOL_CALL_BEGIN = "<|tool_call_begin|>"
TOOL_CALL_END = "<|tool_call_end|>"
TOOL_SEP = "<|tool_sep|>"
SPECIAL_TOKENS = [
TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
]
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.position = 0
# Explicit state flags for robust streaming
self.tool_block_started = False
self.tool_block_finished = False
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none':
request.skip_special_tokens = False
return request
@staticmethod
def _parse_steptml_invoke(
action_text: str
) -> tuple[Optional[str], Optional[dict[str, str]]]:
func_name_match = re.search(r'<steptml:invoke name="([^"]+)">',
action_text)
if not func_name_match:
return None, None
func_name = func_name_match.group(1)
params: dict[str, str] = {}
param_matches = re.findall(
r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>',
action_text)
for name, value in param_matches:
params[name] = value.strip()
return func_name, params
def _cast_arguments(
self,
func_name: str,
params: dict[str, Any],
request: ChatCompletionRequest,
) -> dict[str, Any]:
for tool in request.tools or []:
if tool.function.name == func_name:
schema = tool.function.parameters or {}
properties = schema.get("properties", {})
for key, value in params.items():
if not isinstance(value, str):
continue
prop = properties.get(key, {})
typ = prop.get("type")
if typ == "string":
params[key] = value.strip()
elif typ == "integer":
with contextlib.suppress(ValueError):
params[key] = int(value)
elif typ == "number":
with contextlib.suppress(ValueError):
params[key] = float(value)
elif typ == "boolean":
lower_val = value.lower()
params[key] = lower_val == "true" if lower_val in (
"true", "false") else value
elif typ == "null":
params[key] = None if value.lower(
) == "null" else value
break
return params
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# The main loop processes the stream from the last known position.
while True:
if self.position >= len(current_text):
return None # We've processed the entire stream.
unprocessed_text = current_text[self.position:]
# STATE: After all tools are done, all subsequent text is content.
if self.tool_block_finished:
self.position = len(current_text)
return DeltaMessage(content=unprocessed_text)
# STATE: Before the tool block has started.
if not self.tool_block_started:
if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
self.position += len(self.TOOL_CALLS_BEGIN)
self.tool_block_started = True
continue # Token consumed, re-loop.
start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
if start_pos == -1:
if self.TOOL_CALLS_BEGIN.startswith(
unprocessed_text.strip()) and unprocessed_text:
return None # It's a prefix, wait.
self.position = len(current_text)
return DeltaMessage(content=unprocessed_text)
else:
content = unprocessed_text[:start_pos]
self.position += len(content)
return DeltaMessage(content=content)
# STATE: Inside the main tool block.
offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
unprocessed_text = unprocessed_text.lstrip()
self.position += offset
if unprocessed_text.startswith(self.TOOL_CALLS_END):
self.position += len(self.TOOL_CALLS_END)
self.tool_block_finished = True
self.current_tool_id = -1
continue
# Check if we are between tool calls.
tool_finished = (
self.current_tool_id != -1 and
self.prev_tool_call_arr[self.current_tool_id].get("finished"))
if self.current_tool_id == -1 or tool_finished:
if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
self.position += len(self.TOOL_CALL_BEGIN)
if self.current_tool_id == -1:
self.current_tool_id = 0
else:
self.current_tool_id += 1
self.current_tool_name_sent = False
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
self.prev_tool_call_arr[
self.current_tool_id]["finished"] = False
continue
if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
return None
# STATE: Parsing an active tool call.
if self.current_tool_id != -1 and not self.prev_tool_call_arr[
self.current_tool_id].get("finished", False):
end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
if end_tool_pos == -1:
tool_body = unprocessed_text
else:
tool_body = unprocessed_text[:end_tool_pos]
if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
tool_body):
return None
function_name, arguments = self._parse_steptml_invoke(
tool_body)
if not function_name:
return None
tool_call_arr = {
"name": function_name,
"parameters": arguments or {}
}
# Send the function name as soon as it's parsed.
if not self.current_tool_name_sent:
self.current_tool_name_sent = True
self.prev_tool_call_arr[self.current_tool_id].update(
tool_call_arr)
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name))
])
# Update our internal state with the latest parsed arguments.
self.prev_tool_call_arr[
self.current_tool_id].update( # noqa: E501
tool_call_arr)
# Only send arguments when the tool call is complete.
if end_tool_pos != -1:
self.position += end_tool_pos + len(self.TOOL_CALL_END)
self.prev_tool_call_arr[
self.current_tool_id]["finished"] = True
final_args = self._cast_arguments(
function_name,
tool_call_arr.get("parameters", {}), # type: ignore
request)
if final_args:
final_args_json = json.dumps(final_args,
ensure_ascii=False)
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=final_args_json))
])
# If tool is not finished, return None to wait for more tokens.
return None
return None
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
if self.TOOL_CALLS_BEGIN not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
if self.TOOL_CALLS_END not in rest:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
content = (pre_text + post_text).strip()
tool_calls: list[ToolCall] = []
call_parts = tool_block.split(self.TOOL_CALL_BEGIN)
for part in call_parts:
if not part or self.TOOL_CALL_END not in part:
continue
call_content = part.split(self.TOOL_CALL_END, 1)[0]
if self.TOOL_SEP not in call_content:
continue
type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
if type_part.strip() != "function":
continue
function_name, params_dict = self._parse_steptml_invoke(
invoke_part)
if function_name and params_dict is not None:
params_dict = self._cast_arguments(function_name, params_dict,
request)
params_str = json.dumps(params_dict, ensure_ascii=False)
tool_calls.append(
ToolCall(function=FunctionCall(name=function_name,
arguments=params_str)))
if tool_calls:
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
\ No newline at end of file
......@@ -120,6 +120,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
......@@ -228,6 +229,7 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jurassic model."""
from collections.abc import Iterable
from typing import Any, Optional
import torch
from torch import nn
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
class FusedMoEBlock(nn.Module):
def __init__(self,
config: ModelConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.moe_num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.moe_num_experts}.")
self.experts = FusedMoE(num_experts=config.moe_num_experts,
top_k=config.moe_top_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_expert_weight,
quant_config=quant_config,
prefix=f"{prefix}.experts")
self.gate = ReplicatedLinear(config.hidden_size,
config.moe_num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(orig_shape)
class Step3TextMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
self.hidden_size = hidden_size
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(hidden_states)
intermediate_act = self.act_fn(gate_up)
output, _ = self.down_proj(intermediate_act)
return output
class Step3TextAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
norm_eps: float,
rope_theta: int,
share_q_dim: Optional[int] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embedding: int = 8192,
head_dim: int = 256,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
if num_kv_heads != 1:
raise ValueError(f"Step3TextAttention num_kv_heads must be 1, "
f"but got {num_kv_heads}.")
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.q_size = share_q_dim if share_q_dim else self.head_dim
self.qkv_proj = ReplicatedLinear(
hidden_size,
self.q_size + self.kv_size * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
self.wq = ColumnParallelLinear(
self.q_size,
self.head_dim * self.total_num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq",
)
self.rotary_emb = get_rope(self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embedding,
base=rope_theta,
rope_scaling=rope_scaling)
scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn")
def forward(self, positions: torch.Tensor,
hidden_states: torch.Tensor) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = self.inter_norm(q)
q = self.wq(q)[0]
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
residual, _ = self.o_proj(attn_output)
return residual
class Step3TextDecoderLayer(nn.Module):
def __init__(self,
config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
config = config.hf_config
self.hidden_size = config.hidden_size
rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = Step3TextAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
norm_eps=config.rms_norm_eps,
max_position_embedding=config.max_position_embedding,
head_dim=config.head_dim,
share_q_dim=config.share_q_dim,
rope_theta=config.rope_theta,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn")
layer_idx = int(prefix.split("layers.")[1].split(".")[0])
moe_layers_enum = getattr(config, "moe_layers_enum", None)
if moe_layers_enum is not None:
moe_layers_idx = [
int(i) for i in moe_layers_enum.strip().split(',')
]
else:
# Default to 1dense.
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
if layer_idx in moe_layers_idx:
self.moe = FusedMoEBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.moe")
self.share_expert = Step3TextMLP(
hidden_size=self.hidden_size,
intermediate_size=config.share_expert_dim,
hidden_act="silu",
quant_config=quant_config,
prefix=f"{prefix}.share_expert")
self.use_moe = True
else:
self.mlp = Step3TextMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act="silu",
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.use_moe = False
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self, positions: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.use_moe:
share_output = self.share_expert(hidden_states)
moe_output = self.moe(hidden_states)
hidden_states = share_output + moe_output
else:
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class Step3TextModel(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.config = config
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Step3TextDecoderLayer(config=vllm_config.
model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual,
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Step3TextForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
self.config = config
self.vllm_config = vllm_config
self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None):
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
qkv_params_mapping = [
# (param_name, shard_name, relative_start_idx, relative_end_idx)
(".qkv_proj", ".q_proj", 0, self.config.share_q_dim /
(self.config.share_q_dim + self.config.head_dim * 2)),
(".qkv_proj", ".k_proj", self.config.share_q_dim /
(self.config.share_q_dim + self.config.head_dim * 2),
(self.config.share_q_dim + self.config.head_dim) /
(self.config.share_q_dim + self.config.head_dim * 2)),
(".qkv_proj", ".v_proj",
(self.config.share_q_dim + self.config.head_dim) /
(self.config.share_q_dim + self.config.head_dim * 2),
(self.config.share_q_dim + self.config.head_dim * 2) /
(self.config.share_q_dim + self.config.head_dim * 2)),
]
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = [
(".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
(".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
(".moe.experts.w2_weight", ".moe.down_proj.weight", "w2")
]
disable_moe_stacked_params = [
data[1] for data in expert_params_mapping
]
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
if any(disable_moe_stacked_param in name
for disable_moe_stacked_param in
disable_moe_stacked_params):
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
for expert_id in range(loaded_weight.shape[0]):
loaded_weight_expert = loaded_weight[expert_id]
weight_loader(param,
loaded_weight_expert,
name,
shard_id=shard_id,
expert_id=expert_id)
loaded_params.add(name)
break
else:
for (param_name, weight_name, start_idx,
end_idx) in qkv_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
dim = param.shape[param.output_dim]
begin_idx = int(start_idx * dim)
end_idx = int(end_idx * dim)
param_slice = param.narrow(param.output_dim, begin_idx,
end_idx - begin_idx)
param_slice.copy_(loaded_weight)
loaded_params.add(name)
break
else:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from itertools import product
from math import ceil, sqrt
from typing import Any, Literal, Optional, TypedDict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
class Step3VLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
patch_pixel_values: Optional[torch.Tensor]
num_patches: list[int]
class Step3VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
image_embeds: torch.Tensor
Step3VLImageInputs = Union[Step3VLImagePixelInputs,
Step3VLImageEmbeddingInputs]
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
MAX_IMAGE_SIZE: int = 3024
class Step3VisionProcessor:
def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
patch_size = patch_size if patch_size is not None else size
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean, std),
transforms.Resize(
(size, size),
interpolation=InterpolationMode.BICUBIC if interpolation_mode
== "bicubic" else InterpolationMode.BILINEAR,
antialias=True),
])
self.patch_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean, std),
transforms.Resize(
(patch_size, patch_size),
interpolation=InterpolationMode.BICUBIC if interpolation_mode
== "bicubic" else InterpolationMode.BILINEAR,
antialias=True),
]) if patch_size is not None else None
def __call__(self, image, is_patch=False):
if is_patch:
return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
else:
return {"pixel_values": self.transform(image).unsqueeze(0)}
class ImagePatcher:
def determine_window_size(self, long: int, short: int) -> int:
if long <= 728:
return short if long / short > 1.5 else 0
return min(short, 504) if long / short > 4 else 504
def slide_window(
self,
width: int,
height: int,
sizes: list[tuple[int, int]],
steps: list[tuple[int, int]],
img_rate_thr: float = 0.6,
) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
windows = []
# Sliding windows.
for size, step in zip(sizes, steps):
size_w, size_h = size
step_w, step_h = step
x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
1)
x_start = [step_w * i for i in range(x_num)]
if len(x_start) > 1 and x_start[-1] + size_w > width:
x_start[-1] = width - size_w
y_num = 1 if height <= size_h else ceil((height - size_h) /
step_h + 1)
y_start = [step_h * i for i in range(y_num)]
if len(y_start) > 1 and y_start[-1] + size_h > height:
y_start[-1] = height - size_h
start = np.array(list(product(y_start, x_start)), dtype=int)
start[:, [0, 1]] = start[:, [1, 0]]
windows.append(np.concatenate([start, start + size], axis=1))
windows = np.concatenate(windows, axis=0)
return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
int(box[3] - box[1])) for box in windows], (x_num, y_num)
def square_pad(self, img: Image.Image) -> Image.Image:
w, h = img.size
if w == h:
return img
size = max(w, h)
padded = Image.new(img.mode, (size, size), 0)
padded.paste(img, (0, 0))
return padded
def get_image_size_for_padding(self, img_width: int,
img_height: int) -> tuple[int, int]:
ratio = img_width / img_height
if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
new_size = max(img_height, img_width)
return new_size, new_size
return img_width, img_height
def get_image_size_for_preprocess(self, img_width: int,
img_height: int) -> tuple[int, int]:
if max(img_height, img_width) > MAX_IMAGE_SIZE:
scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
img_width = int(img_width * scale_factor)
img_height = int(img_height * scale_factor)
return img_width, img_height
def get_image_size_for_crop(self, img_width: int, img_height: int,
window_size: int):
w_ratio = img_width / window_size
h_ratio = img_height / window_size
if w_ratio < 1:
width_new = img_width
else:
decimal_w = w_ratio - img_width // window_size
w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
width_new = window_size * w_ratio
if h_ratio < 1:
height_new = img_height
else:
decimal_h = h_ratio - img_height // window_size
h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
height_new = window_size * h_ratio
return int(width_new), int(height_new)
def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
target = img.crop((j, i, j + tw, i + th))
return target
def get_num_patches(self, img_width: int,
img_height: int) -> tuple[int, int]:
img_width, img_height = self.get_image_size_for_padding(
img_width, img_height)
img_width, img_height = self.get_image_size_for_preprocess(
img_width, img_height)
window_size = self.determine_window_size(max(img_height, img_width),
min(img_height, img_width))
if window_size == 0:
return 0, 0
else:
img_width, img_height = self.get_image_size_for_crop(
img_width, img_height, window_size)
center_list, (x_num, y_num) = self.slide_window(
img_width, img_height, [(window_size, window_size)],
[(window_size, window_size)])
full_rows = (len(center_list) - 1) // x_num + 1
if len(center_list) > 0 and len(center_list) % x_num == 0:
full_rows -= 1
return len(center_list), full_rows
def __call__(
self, img: Image.Image
) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
img_width, img_height = img.size
new_img_width, new_img_height = self.get_image_size_for_padding(
img_width, img_height)
if new_img_width != img_width or new_img_height != img_height:
img = self.square_pad(img)
img_width, img_height = img.size
new_img_width, new_img_height = self.get_image_size_for_preprocess(
img_width, img_height)
img = img.resize((new_img_width, new_img_height),
Image.Resampling.BILINEAR)
window_size = self.determine_window_size(
max(new_img_height, new_img_width),
min(new_img_height, new_img_width))
if window_size == 0:
return img, [], None
else:
new_img_width, new_img_height = self.get_image_size_for_crop(
new_img_width, new_img_height, window_size)
if (new_img_width, new_img_height) != (img_width, img_height):
img_for_crop = img.resize((new_img_width, new_img_height),
Image.Resampling.BILINEAR)
else:
img_for_crop = img
patches = []
newlines = []
center_list, (x_num, y_num) = self.slide_window(
new_img_width, new_img_height, [(window_size, window_size)],
[(window_size, window_size)])
for patch_id, center_lf_point in enumerate(center_list):
x, y, patch_w, patch_h = center_lf_point
big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
patch_w)
patches.append(big_patch)
if (patch_id + 1) % x_num == 0:
newlines.append(patch_id)
if newlines and newlines[-1] == len(patches) - 1:
newlines.pop()
return img, patches, [i in newlines for i in range(len(patches))
] if len(patches) > 0 else None
class Step3VLProcessor:
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
self.image_size = 728
self.patch_size = 504
self.image_preprocessor = Step3VisionProcessor(self.image_size,
"bilinear",
self.patch_size)
self.num_image_feature_size = 169
self.num_patch_feature_size = 81
self.image_token = "<im_patch>"
self.image_feature_placeholder = (self.image_token *
self.num_image_feature_size)
self.patch_feature_placeholder = (self.image_token *
self.num_patch_feature_size)
self.patcher = ImagePatcher()
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[self.image_token]
def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
num_patches, num_newlines = self.patcher.get_num_patches(
img_width, img_height)
return num_patches * (
self.num_patch_feature_size +
2) + self.num_image_feature_size + 2 + num_newlines
def _split_images(self,
images: list[Image.Image]) -> list[ImageWithPatches]:
result = []
for img in images:
result.append(self.patcher(img))
return result
def _convert_images_to_pixel_values(
self,
images: list[Image.Image],
is_patch: bool = False,
) -> list[torch.Tensor]:
return [
self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
for img in images
]
def _get_patch_repl(
self,
num_patches: int,
patch_newline_mask: list[bool] | None,
) -> tuple[str, list[int]]:
text = ""
token_ids = []
for i in range(num_patches):
assert len(patch_newline_mask) == num_patches
text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
token_ids.extend(
[self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
[self.image_token_id] * self.num_patch_feature_size +
[self.tokenizer.convert_tokens_to_ids("<patch_end>")])
if patch_newline_mask and patch_newline_mask[i]:
text += "<patch_newline>"
token_ids.append(
self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
return text, token_ids
def _get_image_repl(
self,
num_images: int,
) -> tuple[str, list[int]]:
text = f"<im_start>{self.image_feature_placeholder}<im_end>"
token_ids = [
self.tokenizer.convert_tokens_to_ids("<im_start>")
] + [self.image_token_id] * self.num_image_feature_size + [
self.tokenizer.convert_tokens_to_ids("<im_end>")
]
return text * num_images, token_ids * num_images
def _get_image_repl_features(
self,
num_images: int,
num_patches: int,
patch_new_line_idx: Optional[list[bool]],
) -> tuple[str, list[int]]:
if num_patches > 0:
patch_repl, patch_repl_ids = self._get_patch_repl(
num_patches, patch_new_line_idx)
else:
patch_repl = ""
patch_repl_ids = []
image_repl, image_repl_ids = self._get_image_repl(num_images)
return patch_repl + image_repl, patch_repl_ids + image_repl_ids
def replace_placeholder(self, text: str, placeholder: str,
repls: list[str]) -> str:
parts = text.split(placeholder)
if len(parts) - 1 != len(repls):
raise ValueError(
"The number of placeholders does not match the number of replacements." # noqa: E501
)
result = [parts[0]]
for i, repl in enumerate(repls):
result.append(repl)
result.append(parts[i + 1])
return "".join(result)
def __call__(
self,
text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if len(images) == 0:
image_inputs = {}
text_inputs = self.tokenizer(text)
else:
splitted_images_data = self._split_images(images)
pixel_values_lst = []
patch_pixel_values_lst = []
patch_newline_mask_lst = []
image_repl_str_lst = []
image_repl_ids_lst = []
num_patches = []
for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
pixel_values_lst.extend(
self._convert_images_to_pixel_values([raw_img]))
if len(img_patches) > 0:
patch_pixel_values_lst.extend(
self._convert_images_to_pixel_values(img_patches,
is_patch=True))
num_patches.append(len(img_patches))
image_repl_str, image_repl_ids = self._get_image_repl_features(
1, len(img_patches), patch_newline_mask)
image_repl_str_lst.append(image_repl_str)
image_repl_ids_lst.extend(image_repl_ids)
if patch_newline_mask is not None:
patch_newline_mask_lst.extend(patch_newline_mask)
image_inputs = {
"pixel_values": torch.cat(pixel_values_lst),
"num_patches": num_patches,
}
if patch_pixel_values_lst:
image_inputs["patch_pixel_values"] = torch.cat(
patch_pixel_values_lst)
if patch_newline_mask_lst:
image_inputs["patch_newline_mask"] = torch.tensor(
patch_newline_mask_lst, dtype=torch.bool)
text = [
self.replace_placeholder(t, self.image_token,
image_repl_str_lst) for t in text
]
text_inputs = self.tokenizer(text)
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
class Step3VLProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self) -> Step3VLProcessor:
return Step3VLProcessor(
self.get_hf_config(),
self.get_tokenizer(),
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor()
return hf_processor.get_num_image_tokens(
self.get_image_size_with_most_features().width,
self.get_image_size_with_most_features().height)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_image_size_with_most_features(self) -> ImageSize:
return ImageSize(3024, 3024)
def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
if len(mm_data) != 1 or "image" not in mm_data:
raise ValueError(
"mm_data could only contain one key 'image' for steo1o")
image_data = mm_data["image"]
if not isinstance(image_data, (list, tuple)):
image_data = [image_data]
return sum(self.get_hf_processor().get_num_image_tokens(
img.width, img.height) for img in image_data)
class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
return "<im_patch>" * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]
):
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_placeholder_token_id = hf_processor.image_token_id
batch_num_patches = out_mm_kwargs["num_patches"].tolist()
def get_replacement_step1o(item_idx: int):
img_out = out_mm_kwargs.get_item("image", item_idx)
num_patches = batch_num_patches[item_idx]
if num_patches > 0:
patch_newline_mask = img_out["patch_newline_mask"].data.tolist(
)
image_repl_ids = hf_processor._get_image_repl_features(
1, num_patches, patch_newline_mask)[1]
else:
image_repl_ids = hf_processor._get_image_repl_features(
1, 0, None)[1]
return PromptUpdateDetails.select_token_id(
seq=image_repl_ids,
embed_token_id=image_placeholder_token_id,
)
return [
PromptReplacement(
modality="image",
target=[image_placeholder_token_id],
replacement=get_replacement_step1o,
)
]
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_patches = hf_inputs.get("num_patches", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
patch_pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"),
patch_newline_mask=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches),
)
def get_abs_pos(abs_pos, tgt_size):
dim = abs_pos.size(-1)
abs_pos_new = abs_pos.squeeze(0)
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
old_pos_embed = old_pos_embed.view(1, src_size, src_size,
dim).permute(0, 3, 1,
2).contiguous()
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode='bicubic',
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1,
dim)
return vision_pos_embed
else:
return abs_pos
class Step3VisionEmbeddings(nn.Module):
def __init__(self, config: Step3VisionEncoderConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=True,
)
self.num_patches = (self.image_size // self.patch_size)**2
self.pad_tp_size = 4 # hard code for padding
# To load the pretrained weights, we still use P+1 as the seqlen
self.position_embedding = torch.nn.Embedding(self.num_patches + 1,
self.embed_dim)
self.register_buffer("position_ids",
torch.arange(self.num_patches + 1).expand(
(1, -1)),
persistent=False)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(
pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
# pad
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + get_abs_pos(
self.position_embedding(self.position_ids), patch_embeds.size(1))
embeddings = torch.cat([
embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1,
1), embeddings
],
dim=1)
return embeddings
class Step3VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.total_num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.total_num_heads
self.scale = self.head_dim**-0.5
tp_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.qkv_proj = QKVParallelLinear(self.embed_dim,
self.head_dim,
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=prefix)
self.out_proj = RowParallelLinear(self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=prefix)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
# get query proj
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
k = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
v = v.view(bsz, tgt_len, self.num_heads, self.head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(q,
k,
v,
scale=self.scale,
is_causal=False)
attn_output = attn_output.transpose(1, 2).reshape(
bsz, tgt_len, self.num_heads * self.head_dim)
attn_output, _ = self.out_proj(attn_output)
return attn_output
class Step3VisionMLP(nn.Module):
def __init__(self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=prefix)
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=prefix)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class Step3VisionEncoderLayer(nn.Module):
def __init__(self,
config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Step3VisionAttention(config,
quant_config,
prefix=f"{prefix}.self_attn")
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.FloatTensor:
hidden_states = hidden_states + self.layer_norm1(
self.self_attn(hidden_states))
hidden_states = hidden_states + self.layer_norm2(
self.mlp(hidden_states))
return hidden_states
class Step3VisionEncoder(nn.Module):
def __init__(self,
config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Step3VisionEncoderLayer(config,
quant_config,
prefix=f"{prefix}.layers.{i}")
for i in range(config.num_hidden_layers)
])
def forward(
self,
inputs_embeds,
):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
return hidden_states
class Step3VisionTransformer(nn.Module):
def __init__(self,
config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.image_size = config.image_size
self.embeddings = Step3VisionEmbeddings(config)
self.transformer = Step3VisionEncoder(config,
quant_config,
prefix=f"{prefix}.transformer")
def forward(
self,
pixel_values: torch.Tensor,
):
hidden_states = self.embeddings(pixel_values)
hidden_states = self.transformer(inputs_embeds=hidden_states)
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(Step3VLMultiModalProcessor,
info=Step3VLProcessingInfo,
dummy_inputs=Step3VLDummyInputsBuilder)
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"model.": "language_model.model.",
"lm_head.": "language_model.lm_head.",
})
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return "<im_patch>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.vision_model = Step3VisionTransformer(config.vision_config,
None,
prefix=maybe_prefix(
prefix, "vision_model"))
self.vit_downsampler = nn.Conv2d(
config.vision_config.hidden_size,
config.vision_config.output_hidden_size,
kernel_size=2,
stride=config.understand_projector_stride)
self.vit_downsampler2 = nn.Conv2d(
config.vision_config.output_hidden_size,
config.vision_config.output_hidden_size * 2,
kernel_size=3,
stride=2,
padding=1,
)
self.vit_large_projector = nn.Linear(
config.vision_config.output_hidden_size * 2,
config.hidden_size,
bias=config.projector_bias,
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Step3VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
patch_pixel_values = kwargs.pop("patch_pixel_values", None)
num_patches = kwargs.pop("num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
pixel_values = flatten_bn(pixel_values, concat=True)
if pixel_values.dim() >= 3:
pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
if patch_pixel_values is not None:
patch_pixel_values = flatten_bn(patch_pixel_values,
concat=True)
patch_pixel_values = patch_pixel_values.view(
-1, *patch_pixel_values.shape[-3:])
# Handle empty patch_pixel_values by setting to None
if patch_pixel_values.shape[0] == 0:
patch_pixel_values = None
num_patches = flatten_bn(num_patches, concat=True).tolist()
return Step3VLImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values.to(self.dtype).to(self.device),
patch_pixel_values=patch_pixel_values.to(self.dtype).to(
self.device) if patch_pixel_values is not None else None,
num_patches=num_patches,
)
if image_embeds is not None:
if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
else:
raise ValueError(
f"Unexpected shape for image_embeds: {image_embeds.shape}")
return Step3VLImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds.to(self.dtype).to(self.device),
)
return None
def _process_image_features(self,
image_features: torch.Tensor) -> torch.Tensor:
B, P = image_features.shape[:2]
HW = int(sqrt(P))
image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
image_features = self.vit_downsampler(image_features)
image_features = self.vit_downsampler2(image_features)
n_dim = image_features.size(1)
image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
image_features = self.vit_large_projector(image_features)
return image_features
def _get_vision_model_output(self,
input_tensor: torch.Tensor) -> torch.Tensor:
return self.vision_model(input_tensor)[:, 4:]
def _process_image_input(
self, image_input: Step3VLImageInputs) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
image_features = image_input["image_embeds"]
else:
image_features = self._get_vision_model_output(
image_input["pixel_values"])
patch_image_features = self._get_vision_model_output(
image_input["patch_pixel_values"]
) if image_input["patch_pixel_values"] is not None else None
num_patches = image_input["num_patches"]
image_features = self._process_image_features(image_features)
patch_image_features = self._process_image_features(
patch_image_features) if patch_image_features is not None else None
merged_image_features = []
cur_patch_idx = 0
for i, num_patch in enumerate(num_patches):
cur_feature = []
if num_patch > 0:
patch_slice = patch_image_features[
cur_patch_idx:cur_patch_idx + num_patch]
cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
cur_feature.append(image_features[i].view(
-1, image_features.shape[-1]))
cur_patch_idx += num_patch
merged_image_features.append(
torch.cat(cur_feature) if len(cur_feature) >
1 else cur_feature[0])
return merged_image_features
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
if multimodal_embeddings is None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
else:
is_text = input_ids != self.config.image_token_id
text_ids = input_ids[is_text]
text_embeds = self.language_model.model.get_input_embeddings(
text_ids)
inputs_embeds = torch.empty(input_ids.shape[0],
text_embeds.shape[-1],
dtype=text_embeds.dtype,
device=text_embeds.device)
inputs_embeds[is_text] = text_embeds
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_weights = loader.load_weights(weights,
mapper=self.hf_to_vllm_mapper)
return loaded_weights
\ No newline at end of file
......@@ -6,6 +6,7 @@ from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser
from .step3_reasoning_parser import Step3ReasoningParser
__all__ = [
"ReasoningParser",
......@@ -14,4 +15,5 @@ __all__ = [
"GraniteReasoningParser",
"Qwen3ReasoningParser",
"Glm4MoeModelReasoningParser",
"Step3ReasoningParser",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional, Union
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
@ReasoningParserManager.register_module("step3")
class Step3ReasoningParser(ReasoningParser):
"""
Reasoning parser for Step3 model.
The Step3 model uses </think> token to denote the end of reasoning
text. This parser extracts all content before </think> as reasoning content.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
self.think_end_token = "</think>"
self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
self.think_end_token_id = self.vocab.get(self.think_end_token)
if self.think_end_token_id is None:
raise RuntimeError(
"Step3 reasoning parser could not locate think end "
"token in the tokenizer!")
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""
Extract reasoning content from a delta message.
Handles streaming output where previous + delta = current.
Uses token IDs for faster processing.
For text "abc</think>xyz":
- 'abc' goes to reasoning_content
- 'xyz' goes to content
"""
# Skip single special token
if len(delta_token_ids
) == 1 and delta_token_ids[0] == self.think_end_token_id:
return None
if self.think_end_token_id in delta_token_ids:
# </think> in delta, extract reasoning content and remaining content
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
# </think> already seen in previous text, everything is content
return DeltaMessage(content=delta_text)
else:
# No </think> seen yet, everything is reasoning
return DeltaMessage(reasoning_content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
# Check if the model output contains the </think> token
if self.think_end_token not in model_output:
# If no </think> token, everything is reasoning content
return model_output, None
else:
# Find the first occurrence of </think>
end_index = model_output.find(self.think_end_token)
reasoning_content = model_output[:end_index]
# Content after </think> token
content = model_output[end_index + len(self.think_end_token):]
if len(content) == 0:
content = None
return reasoning_content, content
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
if self.think_end_token_id not in input_ids[:-1]:
return []
else:
return input_ids[input_ids.index(self.think_end_token_id) + 1:]
\ No newline at end of file
......@@ -39,6 +39,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
OvisConfig, RWConfig,
Step3TextConfig, Step3VLConfig,
SkyworkR1VChatConfig, SolarConfig,
Telechat2Config, UltravoxConfig)
# yapf: enable
......@@ -97,6 +98,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"skywork_chat": SkyworkR1VChatConfig,
"telechat": Telechat2Config,
"ultravox": UltravoxConfig,
"step3_vl": Step3VLConfig,
"step3_text": Step3TextConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}
......
......@@ -28,6 +28,9 @@ from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
......@@ -56,4 +59,7 @@ __all__ = [
"SolarConfig",
"Telechat2Config",
"UltravoxConfig",
"Step3VLConfig",
"Step3VisionEncoderConfig",
"Step3TextConfig",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
from transformers.configuration_utils import PretrainedConfig
class Step3VisionEncoderConfig(PretrainedConfig):
model_type = "step3_vision_encoder"
def __init__(
self,
hidden_size=1792,
intermediate_size=3072,
output_hidden_size=4096,
num_hidden_layers=63,
num_attention_heads=16,
num_channels=3,
image_size=728,
patch_size=14,
hidden_act="quick_gelu",
layer_norm_eps=1e-5,
**kwargs,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.output_hidden_size = output_hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
super().__init__(**kwargs)
class Step3TextConfig(PretrainedConfig):
model_type = "step3_text"
architectures = ["Step3TextForCausalLM"]
def __init__(
self,
hidden_size: int = 7168,
intermediate_size: int = 18432,
num_attention_heads: int = 64,
num_attention_groups: int = 1,
num_hidden_layers: int = 61,
max_seq_len: int = 65536,
vocab_size: int = 128815,
rms_norm_eps: float = 1e-5,
moe_intermediate_size: int = 5120,
moe_num_experts: int = 48,
moe_top_k: int = 3,
rope_theta: float = 500000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embedding: int = 65536,
share_expert_dim: int = 5120,
share_q_dim: int = 2048,
head_dim: int = 256,
norm_expert_weight: bool = False,
moe_layers_enum: tuple[int,
...] = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59),
**kwargs,
) -> None:
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_attention_groups = num_attention_groups
self.num_hidden_layers = num_hidden_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.rms_norm_eps = rms_norm_eps
self.moe_intermediate_size = moe_intermediate_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.max_position_embedding = max_position_embedding
self.share_expert_dim = share_expert_dim
self.share_q_dim = share_q_dim
self.head_dim = head_dim
self.norm_expert_weight = norm_expert_weight
self.moe_layers_enum = moe_layers_enum
super().__init__(**kwargs)
class Step3VLConfig(PretrainedConfig):
model_type = "step3_vl"
def __init__(
self,
vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
text_config: Optional[Union[dict, Step3TextConfig]] = None,
understand_projector_stride: int = 1,
projector_bias: bool = True,
image_token_id: int = 128001,
**kwargs,
) -> None:
if vision_config is None:
vision_config = Step3VisionEncoderConfig()
elif isinstance(vision_config, dict):
vision_config = Step3VisionEncoderConfig(**vision_config)
self.vision_config = vision_config
if text_config is None:
text_config = Step3TextConfig()
elif isinstance(text_config, dict):
text_config = Step3TextConfig(**text_config)
self.text_config = text_config
self.understand_projector_stride = understand_projector_stride
self.projector_bias = projector_bias
self.hidden_size = text_config.hidden_size
self.image_token_id = image_token_id
super().__init__(**kwargs)
\ No newline at end of file
......@@ -690,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
#return common_attn_metadata.max_query_len == 1
if not self.use_spec_decode:
return common_attn_metadata.max_query_len == 1
return self._num_prefills == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment