Unverified commit 4f564b9e authored by Zheng Li, committed by GitHub

model: support qwen3-vl series (#10323)


Co-authored-by: ocss884 <ocss.lin@gmail.com>
Co-authored-by: cao1zhg <653506626@qq.com>
Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yhyang201 <47235274+yhyang201@users.noreply.github.com>
Co-authored-by: 瑀澈 <yuche.lz@alibaba-inc.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 98c3b04f
@@ -749,6 +749,8 @@ multimodal_model_archs = [
"Qwen2AudioForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
......
This diff is collapsed.
@@ -1187,7 +1187,7 @@ class MRotaryEmbedding(RotaryEmbedding):
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
-        elif model_type == "qwen2_vl":
+        elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
......
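For the new Qwen3-VL branches, the temporal index is built exactly as for Qwen2-VL: each frame index is broadcast across the whole spatial grid and then flattened. The hunk above is truncated after `.view(-1, 1)`; a minimal standalone sketch of the assumed broadcast-and-flatten, with hypothetical grid sizes:

import torch

# Hypothetical grid: 4 temporal frames over a 6x8 spatial patch grid.
llm_grid_t, llm_grid_h, llm_grid_w = 4, 6, 8
t_index = (
    torch.arange(llm_grid_t)
    .view(-1, 1)
    .expand(-1, llm_grid_h * llm_grid_w)
    .flatten()
)
# t_index is [0]*48 + [1]*48 + [2]*48 + [3]*48, i.e. length t*h*w = 192.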
@@ -507,6 +507,7 @@ def embed_mm_inputs(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
+    use_deepstack: bool = False,
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
@@ -522,7 +523,7 @@
Returns:
Combined embedding tensor with multimodal content integrated
"""
+    other_info = {}
if mm_inputs_list is None:
return None
@@ -532,7 +533,7 @@
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
-    embeddings, masks = [], []
+    embeddings, masks, deepstack_embeddings = [], [], []
# 2. Get multimodal embedding separately
# Try get mm embedding if any
for modality in Modality.all():
@@ -578,6 +579,12 @@
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
+            if use_deepstack and embedding is not None:
+                embedding, deepstack_embedding = (
+                    multimodal_model.separate_deepstack_embeds(embedding)
+                )
+                deepstack_embeddings += [deepstack_embedding]
embeddings += [embedding]
masks += [mask]
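`separate_deepstack_embeds` is defined in one of the collapsed diffs (presumably qwen3_vl.py, which the MoE model below imports from). A plausible standalone sketch of the split it performs, assuming the vision tower returns the main embedding concatenated with one hidden_size-wide block per deepstack layer along the last dim:

import torch

def separate_deepstack_embeds_sketch(embedding: torch.Tensor, hidden_size: int):
    # embedding: (num_mm_tokens, hidden_size * (1 + num_deepstack_layers))
    main = embedding[:, :hidden_size]       # scattered into inputs_embeds below
    deepstack = embedding[:, hidden_size:]  # flattened per-layer visual features
    return main, deepstack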
@@ -591,13 +598,37 @@
inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding
-    for embedding, mask in zip(embeddings, masks):
+    # deepstack embedding
+    if use_deepstack:
+        num_deepstack_embeddings = (
+            len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
+        )
+        deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
+            inputs_embeds.shape[-1] * num_deepstack_embeddings,
+        )
+        input_deepstack_embeds = torch.zeros(
+            deepstack_embedding_shape,
+            device=inputs_embeds.device,
+            dtype=inputs_embeds.dtype,
+        )
+        other_info["input_deepstack_embeds"] = input_deepstack_embeds
+    for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
if embedding is None or mask is None:
continue
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
-    return inputs_embeds
+        if use_deepstack:
+            input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
+                inputs_embeds.device, inputs_embeds.dtype
+            )
+    return inputs_embeds, other_info
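Because `input_deepstack_embeds` is zero-initialized and only rows at multimodal token positions are filled, adding one of its slices to a decoder layer's hidden states later is a no-op for text tokens. A shape sketch under assumed sizes:

# Hypothetical: 10 tokens in the batch, hidden_size=8, 3 deepstack layers.
# inputs_embeds:          (10, 8)
# input_deepstack_embeds: (10, 24)  # three (10, 8) slices, one per early layer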
def general_mm_embed_routine(
@@ -609,6 +640,7 @@
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
+    use_deepstack: bool = False,
**kwargs,
) -> torch.Tensor:
"""
@@ -620,6 +652,7 @@
language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
+        use_deepstack: Whether to use deepstack embeddings
**kwargs: Additional arguments passed to language model
Returns:
@@ -645,16 +678,20 @@
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
-        inputs_embeds = embed_mm_inputs(
+        inputs_embeds, other_info = embed_mm_inputs(
mm_inputs_list=mm_inputs_list,
extend_prefix_lens=extend_prefix_lens,
extend_seq_lens=extend_seq_lens,
input_ids=input_ids,
-            input_embedding=embed_tokens,
+            multimodal_model=multimodal_model,
+            input_embedding=embed_tokens,
data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens,
+            use_deepstack=use_deepstack,
)
+        # forward the deepstack embeddings to the language model (Qwen3-VL)
+        if use_deepstack:
+            kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
# just being defensive here
forward_batch.mm_inputs = None
......
This diff is collapsed.
# Copyright 2025 Qwen Team
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank,
)
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from sglang.srt.models.qwen3_vl import (
Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
class Qwen3MoeLLMModel(Qwen3MoeModel):
def __init__(
self,
*,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
self.hidden_size = config.hidden_size
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        # In the Qwen-VL family, all items share the same feature dim,
        # so features can be concatenated along dim 0.
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
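    # Shape sketch (inferred from the asserts above; sizes hypothetical):
    #   pixel_values:   (num_patches, patch_dim)  -- flattened across all items
    #   image_grid_thw: (num_images, 3)           -- (t, h, w) per image
    #   image_embeds:   one row per merged patch; with deepstack, the last dim
    #   presumably also carries the per-layer blocks that
    #   separate_deepstack_embeds() splits off in embed_mm_inputs.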
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
input_deepstack_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
aux_hidden_states = []
for layer_idx, layer in enumerate(
self.layers[self.start_layer : self.end_layer]
):
layer_idx = layer_idx + self.start_layer
if layer_idx in self.layers_to_capture:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
            # Deepstack: add the visual features captured at selected ViT
            # layers to the hidden states of the first three decoder layers,
            # one hidden_size-wide slice per layer.
if input_deepstack_embeds is not None and layer_idx in range(3):
sep = self.hidden_size * layer_idx
hidden_states = (
hidden_states
+ input_deepstack_embeds[:, sep : sep + self.hidden_size]
)
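            # Worked example (hypothetical hidden_size=4096): layer 0 adds
            # columns 0:4096 of input_deepstack_embeds, layer 1 adds 4096:8192,
            # layer 2 adds 8192:12288; rows for text tokens are all zeros, so
            # only multimodal positions are affected.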
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
def __init__(
self,
*,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super(Qwen3VLForConditionalGeneration, self).__init__()
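        # Note: deliberately bypasses Qwen3VLForConditionalGeneration.__init__
        # (which would build the dense language model); calling the next
        # __init__ in the MRO lets this class assemble its own visual tower,
        # MoE language model, and LM head.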
self.config = config
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)
self.model = Qwen3MoeLLMModel(
config=config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
@property
def use_deepstack(self) -> bool:
return hasattr(self, "deepstack_visual_indexes")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
"""Run forward pass for Qwen3-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
            **NOTE**: If mrope is enabled (the default for the open-source
            Qwen-VL models), the shape will be `(3, seq_len)`;
            otherwise it will be `(seq_len,)`.
            (In that case, forward_batch.mrope_positions replaces it.)
"""
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
use_deepstack=self.use_deepstack,
)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
def load_fused_expert_weights(
self,
name: str,
params_dict: dict,
loaded_weight: torch.Tensor,
shard_id: str,
num_experts: int,
):
param = params_dict[name]
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
weight_loader = param.weight_loader
ep_rank = get_tensor_model_parallel_rank()
ep_size = get_moe_expert_parallel_world_size()
if ep_size == 1:
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
expert_id,
)
else:
experts_per_ep = num_experts // ep_size
start_expert = ep_rank * experts_per_ep
end_expert = (
(ep_rank + 1) * experts_per_ep
if ep_rank != ep_size - 1
else num_experts
)
for idx, expert_id in enumerate(range(start_expert, end_expert)):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
idx,
)
return True
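    # Worked example of the EP partitioning above (hypothetical sizes): with
    # num_experts=128 and ep_size=8, experts_per_ep=16, so EP rank 3 loads
    # checkpoint experts [48, 64) under local ids 0..15; the last rank also
    # absorbs any remainder when num_experts % ep_size != 0.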
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
# Skip loading extra parameters for GPTQ/modelopt models.
ignore_suffixes = (
".bias",
"_bias",
".k_scale",
"_k_scale",
".v_scale",
"_v_scale",
".weight_scale",
"_weight_scale",
".input_scale",
"_input_scale",
)
is_fused_expert = False
fused_expert_params_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
num_experts = self.config.num_experts
# Cache params_dict to avoid repeated expensive traversal of model parameters
if not hasattr(self, "_cached_params_dict"):
self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict
for name, loaded_weight in weights:
if "language_model" in name:
name = name.replace(r"model.language_model.", r"model.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
is_fused_expert = True
expert_params_mapping = fused_expert_params_mapping
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
if "visual" in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(ignore_suffixes) and name not in params_dict:
continue
# [TODO] Skip layers that are on other devices (check if sglang has a similar function)
# if is_pp_missing_parameter(name, self):
# continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Track if this is an expert weight to enable early skipping
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
if "visual" in name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
if is_fused_expert:
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
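                        # Fused checkpoints store gate_proj and up_proj stacked
                        # along the penultimate dim; after the transpose,
                        # chunk(2, dim=-2) splits them into the gate half
                        # (loaded as w1) and the up half (loaded as w3).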
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[0],
"w1",
num_experts,
)
self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[1],
"w3",
num_experts,
)
else:
self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight,
shard_id,
num_experts,
)
else:
# Skip loading extra parameters for GPTQ/modelopt models.
if (
name_mapped.endswith(ignore_suffixes)
and name_mapped not in params_dict
):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
                        # other available replicas.
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
)
name = name_mapped
break
else:
if is_expert_weight:
# This is an expert weight but not mapped to this rank, skip all remaining processing
continue
if "visual" in name:
# adapt to VisionAttention
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
name = name.replace(r"model.visual.", r"visual.")
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(ignore_suffixes) and name not in params_dict:
continue
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
# TODO mimic deepseek
# Lazy initialization of expert weights cache to avoid slowing down load_weights
# if not hasattr(self, "routed_experts_weights_of_layer"):
# self.routed_experts_weights_of_layer = {
# layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
# for layer_id in range(self.start_layer, self.end_layer)
# if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
# }
EntryClass = Qwen3VLMoeForConditionalGeneration
@@ -12,6 +12,8 @@ from torchvision.transforms import InterpolationMode
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
+from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
+from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
@@ -209,7 +211,12 @@ async def preprocess_video(
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
-    models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
+    models = [
+        Qwen2VLForConditionalGeneration,
+        Qwen2_5_VLForConditionalGeneration,
+        Qwen3VLForConditionalGeneration,
+        Qwen3VLMoeForConditionalGeneration,
+    ]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
......