Unverified Commit b7094a5e authored by RunningLeon, committed by GitHub

model: support intern-s1 (#8350)


Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: zxy <zhou0493@e.ntu.edu.sg>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent da0c0260
@@ -448,6 +448,19 @@ register_chat_template(
)
)
register_chat_template(
ChatTemplate(
name="interns1",
default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
stop_str=["<|im_end|>", "<|action_end|>"],
)
)
register_chat_template(
ChatTemplate(
name="granite-3-instruct",
@@ -609,6 +622,14 @@ def match_internvl_chat(model_path: str):
return "internvl-2-5"
@register_chat_template_matching_function
def match_interns1_chat(model_path: str):
if re.search(r"intern-s1", model_path, re.IGNORECASE):
return "interns1"
if re.search(r"interns1", model_path, re.IGNORECASE):
return "interns1"
if __name__ == "__main__":
messages = [
{"role": "system", "content": None}, # None means default
......
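As a quick orientation for the template registered above, here is a minimal standalone sketch of how the interns1 role prefixes and suffixes wrap a conversation. It does not use the sglang ChatTemplate API; the helper name and the shortened system prompt are made up for illustration.

# Hypothetical helper mirroring role_prefix_and_suffix above; illustration only.
ROLE_WRAP = {
    "system": ("<|im_start|>system\n", "<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "<|im_end|>\n"),
    "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
}

def render_interns1(messages):
    # Wrap each message in its role prefix/suffix, then open an assistant turn;
    # generation is expected to stop on "<|im_end|>" or "<|action_end|>".
    parts = [prefix + content + suffix
             for role, content in messages
             for prefix, suffix in [ROLE_WRAP[role]]]
    return "".join(parts) + "<|im_start|>assistant\n"

print(render_interns1([("system", "You are Intern-S1."), ("user", "Hi!")]))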
@@ -10,6 +10,7 @@ from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
Qwen2Config,
Qwen3Config,
)
from sglang.utils import logger
@@ -314,6 +315,8 @@ class InternVLChatConfig(PretrainedConfig):
self.llm_config = InternLM2Config(**llm_config)
elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
self.llm_config = Qwen2Config(**llm_config)
elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
self.llm_config = Qwen3Config(**llm_config)
else:
raise ValueError(
"Unsupported architecture: {}".format(
......
@@ -635,6 +635,7 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
"Phi4MMForCausalLM", "Phi4MMForCausalLM",
"VILAForConditionalGeneration", "VILAForConditionalGeneration",
] ]
......
@@ -623,7 +623,7 @@ def generate_chat_conv(
real_content += content.text
elif content.type == "image_url":
# NOTE: works for llava and intervl2_5
if conv.name in ["internvl-2-5", "interns1"]:
real_content = image_token + real_content
else:
real_content += image_token
@@ -817,6 +817,19 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="interns1",
system_template="<|im_start|>system\n{system_message}",
system_message="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>\n",
stop_str=["<|im_end|>", "<|action_end|>"],
image_token="<image>",
)
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_conv_template(
Conversation(
@@ -986,6 +999,8 @@ register_conv_template(
def match_internvl(model_path: str):
if re.search(r"internvl", model_path, re.IGNORECASE):
return "internvl-2-5"
if re.search(r"interns1", model_path, re.IGNORECASE):
return "interns1"
@register_conv_template_matching_function
......
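For clarity, a tiny hedged sketch of the placement rule changed in generate_chat_conv above: for internvl-2-5 and interns1 the image token is prepended to the user text, while other templates append it (the function name is illustrative, not part of the codebase).

def place_image_token(conv_name, text, image_token="<image>"):
    # Simplified stand-in for the branch in generate_chat_conv.
    if conv_name in ["internvl-2-5", "interns1"]:
        return image_token + text
    return text + image_token

assert place_image_token("interns1", "describe this") == "<image>describe this"
assert place_image_token("qwen2-vl", "describe this") == "describe this<image>"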
@@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
import functools
import math
from functools import lru_cache, partial
from typing import Any, Optional, Tuple, Union
import torch
@@ -18,11 +18,16 @@ _is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel.flash_attn import flash_attn_varlen_func
from sglang.srt.distributed import (
parallel_state,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
)
from sglang.srt.distributed import utils as dist_utils
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -349,25 +354,44 @@ class VisionAttention(nn.Module):
flatten_batch: bool = False,
prefix: str = "",
proj_bias: bool = True,
num_dummy_heads: int = 0,
qkv_bias: bool = True,
qk_normalization: bool = False,
layer_norm_eps: float = 1e-06,
**kwargs,
):
super().__init__()
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_size = world_size
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.dropout = dropout
self.head_size = embed_dim // num_heads
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads
)
self.num_attention_heads_per_partition = dist_utils.divide(
num_dummy_heads + num_heads, world_size
)
self.num_attention_kv_heads_per_partition = dist_utils.divide(
num_dummy_heads + num_heads, world_size
)
self.q_size = self.num_attention_heads_per_partition * self.head_size
self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size
self.qk_normalization = qk_normalization
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size
if self.qk_normalization:
self.q_norm = RMSNorm(
self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
)
self.k_norm = RMSNorm(
self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
)
if global_server_args_dict["mm_attention_backend"] is None:
if qkv_backend is None:
qkv_backend = "sdpa"
@@ -391,26 +415,46 @@ class VisionAttention(nn.Module):
self.qkv_proj = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.head_size,
total_num_heads=num_dummy_heads + num_heads,
total_num_kv_heads=num_dummy_heads + num_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
else:
self.qkv_proj = ColumnParallelLinear(
input_size=embed_dim,
output_size=3 * self.dummy_dim,
bias=qkv_bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.proj = RowParallelLinear(
input_size=self.dummy_dim,
output_size=embed_dim,
bias=proj_bias,
quant_config=quant_config,
prefix=add_prefix("proj", prefix),
)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
"""apply qk norm for internvl vit attn"""
q = q.flatten(1, 2)
k = k.flatten(1, 2)
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
q = self.q_norm(q)
k = self.k_norm(k)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
q = q.unflatten(-1, (-1, self.head_size))
k = k.unflatten(-1, (-1, self.head_size))
return q, k
def forward(
self,
x: torch.Tensor,
@@ -489,6 +533,10 @@ class VisionAttention(nn.Module):
assert k.dim() == 3, k.dim()
assert v.dim() == 3, v.dim()
# internvl
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
output = self.qkv_backend.forward(
q=q,
k=k,
......
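To make the dummy-head padding above concrete, here is a small arithmetic sketch assuming a hypothetical head size of 64; it mirrors the rounding used later in _update_hf_config and the dummy_dim used by qkv_proj, proj, q_norm, and k_norm.

def num_dummy_heads(num_heads, tp_size):
    # Round the head count up to the next multiple of tp_size (only when not divisible).
    if num_heads % tp_size == 0:
        return 0
    return ((num_heads + tp_size) // tp_size) * tp_size - num_heads

head_size = 64  # hypothetical
for heads, tp in [(25, 2), (25, 4), (25, 8)]:
    extra = num_dummy_heads(heads, tp)
    dummy_dim = (heads + extra) * head_size  # width of the padded qkv/proj/norm tensors
    print(f"{heads} heads on tp={tp}: {extra} dummy heads, dummy_dim={dummy_dim}")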
@@ -61,10 +61,15 @@ class RMSNorm(CustomOp):
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.hidden_size = hidden_size
self.variance_size_override = (
None if var_hidden_size == hidden_size else var_hidden_size
)
if _use_aiter:
self._forward_method = self.forward_aiter
@@ -73,6 +78,8 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
@@ -138,7 +145,25 @@ class RMSNorm(CustomOp):
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError(
"Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}"
)
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}"
)
x_var = x[..., : self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = (x * self.weight).to(orig_dtype)
if residual is None:
......
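A minimal sketch, assuming only torch, of what the new var_hidden_size path in RMSNorm.forward_native does: when dummy heads pad the hidden width, the variance is computed over the first var_hidden_size channels only, while the learned scale is still applied over the full padded width (the sizes below are made up).

import torch

def rms_norm_var_override(x, weight, eps=1e-6, var_hidden_size=None):
    # Simplified mirror of the forward_native path above.
    x = x.to(torch.float32)
    x_var = x if var_hidden_size is None else x[..., :var_hidden_size]
    variance = x_var.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

padded, real = 128, 96  # hypothetical padded width vs. real (non-dummy) width
x = torch.randn(2, padded)
out = rms_norm_var_override(x, torch.ones(padded), var_hidden_size=real)
print(out.shape)  # torch.Size([2, 128])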
from typing import Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.internvl import InternVisionModel
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
from sglang.utils import logger
class InternS1ForConditionalGeneration(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
use_flash_attn=True,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self._update_hf_config()
image_size = (
getattr(config, "force_image_size", None) or config.vision_config.image_size
)
patch_size = config.vision_config.patch_size
if isinstance(image_size, list):
image_size = image_size[0]
if isinstance(patch_size, list):
patch_size = patch_size[0]
self.patch_size = patch_size
self.select_layer = config.vision_feature_layer
self.num_image_token = int(
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
)
self.downsample_ratio = config.downsample_ratio
self.ps_version = getattr(config, "ps_version", "v1")
# self.template = getattr(config, 'template', 'internvl2_5')
config.vision_config.use_flash_attn = True if use_flash_attn else False
config.text_config._attn_implementation = (
"flash_attention_2" if use_flash_attn else "eager"
)
logger.info(f"num_image_token: {self.num_image_token}")
logger.info(f"ps_version: {self.ps_version}")
self.vision_model = InternVisionModel(config.vision_config)
if config.text_config.architectures[0] == "Qwen2ForCausalLM":
self.language_model = Qwen2ForCausalLM(
config=config.text_config, quant_config=quant_config
)
elif config.text_config.architectures[0] == "Qwen3MoeForCausalLM":
self.language_model = Qwen3MoeForCausalLM(
config=config.text_config, quant_config=quant_config
)
else:
raise NotImplementedError(
f"{config.text_config.architectures[0]} is not implemented."
)
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(
vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
def _update_hf_config(self):
"""update hf config to support tp"""
world_size = parallel_state.get_tensor_model_parallel_world_size()
num_heads = self.config.vision_config.num_attention_heads
head_dim = self.config.vision_config.hidden_size // num_heads
num_dummy_heads = 0
if num_heads % world_size != 0:
num_dummy_heads = (
(num_heads + world_size) // world_size
) * world_size - num_heads
setattr(self.config.vision_config, "head_dim", head_dim)
setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(
n,
int(h * scale_factor),
int(w * scale_factor),
int(c / (scale_factor * scale_factor)),
)
if self.ps_version == "v1":
logger.warn(
"In ps_version 'v1', the height and width have not been swapped back, "
"which results in a transposed image."
)
else:
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values):
if self.select_layer == -1:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
).last_hidden_state
else:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
).hidden_states[self.select_layer]
vit_embeds = vit_embeds[:, 1:, :]
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
def get_image_feature(self, items: List[MultimodalDataItem]):
"""
Projects the last hidden state from the vision model into language model space.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
pixel_values = torch.cat([item.feature for item in items])
image_features = self.extract_feature(pixel_values)
return image_features
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hs = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
return hs
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = mm_inputs.im_start_id
im_end_id: int = mm_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return helper.pad_input_tokens(input_ids, mm_inputs)
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads"""
num_dummy_heads = self.config.vision_config.num_dummy_heads
if num_dummy_heads == 0:
return loaded_weight
head_dim = self.config.vision_config.head_dim
if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
if name.endswith(".weight"):
dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
elif name.endswith(".bias"):
dummy_shape = [num_dummy_heads, head_dim]
else:
raise RuntimeError(f"Unsupported weight with name={name}")
padded_weight = loaded_weight.new_zeros(dummy_shape)
loaded_weight = torch.cat(
[loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
).flatten(0, 1)
if "attn.proj.weight" in name:
padded_weight = loaded_weight.new_zeros(
loaded_weight.shape[0], head_dim * num_dummy_heads
)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
return loaded_weight
def _mapping_interns1_name(self, name):
names_map = {
"lm_head.weight": "language_model.lm_head.weight",
"model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
"model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
"model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
"model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
"model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
"model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
"model.vision_tower.embeddings.cls_token": "vision_model.embeddings.class_embedding",
"model.vision_tower.embeddings.patch_embeddings.projection.bias": "vision_model.embeddings.patch_embedding.bias",
"model.vision_tower.embeddings.patch_embeddings.projection.weight": "vision_model.embeddings.patch_embedding.weight",
"model.vision_tower.embeddings.position_embeddings": "vision_model.embeddings.position_embedding",
}
if name in names_map:
name = names_map[name]
elif name.startswith("model.language_model."):
name = "language_model.model." + name[len("model.language_model.") :]
elif name.startswith("model.vision_tower."):
name = "vision_model." + name[len("model.vision_tower.") :]
if name.startswith("vision_model.encoder.layer"):
name = name.replace(r".layer.", r".layers.")
name = name.replace(r".attention.", r".attn.attn.")
name = name.replace(r".projection_layer.", r".proj.")
name = name.replace(r".lambda_1", r".ls1")
name = name.replace(r".lambda_2", r".ls2")
name = name.replace(r".layernorm_before.", r".norm1.")
name = name.replace(r".layernorm_after.", r".norm2.")
return name
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = []
if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
name = self._mapping_interns1_name(name)
if "vision_model" in name:
loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
f"Some weights are not initialized from checkpoints: {unloaded_params}"
)
return loaded_params
EntryClass = [InternS1ForConditionalGeneration]
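To keep the shapes in extract_feature and pixel_shuffle above easy to follow, here is a hedged arithmetic sketch of num_image_token and the pixel_shuffle reshape, using InternVL-style values (448-pixel tiles, 14-pixel patches, downsample_ratio 0.5); these numbers are illustrative rather than read from the Intern-S1 config.

image_size, patch_size, downsample_ratio = 448, 14, 0.5  # illustrative values
grid = image_size // patch_size                   # 32 patches per side
num_image_token = int(grid ** 2 * downsample_ratio ** 2)
print(num_image_token)                            # 32 * 32 * 0.25 = 256

# pixel_shuffle maps (N, 32, 32, C) -> (N, 16, 16, 4 * C), so 1024 vision tokens
# collapse to 256 tokens with 4x wider features before the mlp1 projector.
n, c = 1, 1024
h = w = grid
out_tokens = int(h * downsample_ratio) * int(w * downsample_ratio)
out_channels = int(c / (downsample_ratio * downsample_ratio))
print(out_tokens, out_channels)                   # 256 4096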
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
@@ -23,7 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
@@ -39,6 +28,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_janus_pro import DropPath
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
from sglang.utils import logger
@@ -53,7 +43,6 @@ class InternAttention(nn.Module):
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.attn = VisionAttention(
@@ -64,18 +53,16 @@ class InternAttention(nn.Module):
use_qkv_parallel=True,
quant_config=quant_config,
dropout=getattr(config, "dropout", 0.0),
qkv_bias=getattr(config, "qkv_bias", False)
or getattr(config, "attention_bias", False),
num_dummy_heads=getattr(config, "num_dummy_heads", 0),
qk_normalization=getattr(config, "qk_normalization", False)
or getattr(config, "use_qk_norm", False),
flatten_batch=False,
)
self.proj_drop = nn.Dropout(config.dropout)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
@@ -91,8 +78,16 @@ class InternVisionEmbeddings(nn.Module):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = (
config.image_size
if isinstance(config.image_size, int)
else config.image_size[0]
)
self.patch_size = (
config.patch_size
if isinstance(config.patch_size, int)
else config.patch_size[0]
)
self.class_embedding = nn.Parameter(
torch.randn(1, 1, self.embed_dim),
@@ -199,7 +194,7 @@ class InternVisionEncoderLayer(nn.Module):
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
self.attn = InternAttention(config=config, quant_config=quant_config)
self.mlp = InternMLP(config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
@@ -417,7 +412,7 @@ class InternVLChatModel(nn.Module):
super().__init__()
self.config = config
self.quant_config = quant_config
self._update_vision_config()
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
@@ -446,6 +441,10 @@ class InternVLChatModel(nn.Module):
self.language_model = InternLM2ForCausalLM(
config=config.llm_config, quant_config=quant_config
)
elif config.llm_config.architectures[0] == "Qwen3MoeForCausalLM":
self.language_model = Qwen3MoeForCausalLM(
config=config.llm_config, quant_config=quant_config
)
else:
raise NotImplementedError(
f"{config.llm_config.architectures[0]} is not implemented."
@@ -463,6 +462,21 @@ class InternVLChatModel(nn.Module):
nn.Linear(llm_hidden_size, llm_hidden_size),
)
def _update_vision_config(self):
"""update vision config to support tp"""
world_size = parallel_state.get_tensor_model_parallel_world_size()
num_heads = self.config.vision_config.num_attention_heads
head_dim = self.config.vision_config.hidden_size // num_heads
num_dummy_heads = 0
if num_heads % world_size != 0:
num_dummy_heads = (
(num_heads + world_size) // world_size
) * world_size - num_heads
setattr(self.config.vision_config, "head_dim", head_dim)
setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
@@ -545,7 +559,38 @@ class InternVLChatModel(nn.Module):
return helper.pad_input_tokens(input_ids, mm_inputs)
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads"""
num_dummy_heads = self.config.vision_config.num_dummy_heads
if num_dummy_heads == 0:
return loaded_weight
head_dim = self.config.vision_config.head_dim
if "attn.qkv_proj" in name:
wq, wk, wv = loaded_weight.chunk(3, dim=0)
if name.endswith(".weight"):
dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
elif name.endswith(".bias"):
dummy_shape = [num_dummy_heads, head_dim]
else:
raise RuntimeError(f"Unsupported weight with name={name}")
pad_func = lambda x: torch.cat(
[x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
).flatten(0, 1)
wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
loaded_weight = torch.cat([wq, wk, wv], dim=0)
if "attn.proj.weight" in name:
padded_weight = loaded_weight.new_zeros(
loaded_weight.shape[0], head_dim * num_dummy_heads
)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
return loaded_weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = []
if "InternLM2ForCausalLM" in self.config.llm_config.architectures: if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
...@@ -561,15 +606,41 @@ class InternVLChatModel(nn.Module): ...@@ -561,15 +606,41 @@ class InternVLChatModel(nn.Module):
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
elif "Qwen3MoeForCausalLM" in self.config.llm_config.architectures:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
@@ -584,30 +655,55 @@ class InternVLChatModel(nn.Module):
name = name.replace(r"attn.", r"attn.attn.")
name = name.replace(r"qkv.", r"qkv_proj.")
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if "wqkv" in name:
config = self.config
kv_groups = (
config.num_attention_heads // config.num_key_value_heads
)
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(
-1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
)
wq, wk, wv = torch.split(
loaded_weight, [kv_groups, 1, 1], dim=1
)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
weight_loader = param.weight_loader
weight_loader(param, wq, "q")
weight_loader(param, wk, "k")
weight_loader(param, wv, "v")
else:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
if "vision_model" in name:
loaded_weight = self._pad_vit_attn_dummy_heads(
name, loaded_weight
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
......
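For readers unfamiliar with InternLM2's fused wqkv layout handled above, a small hedged sketch with made-up sizes: each key/value head stores its group of query heads, one key head, and one value head contiguously, and the loader splits them back into separate q, k, and v shards.

import torch

num_attention_heads, num_key_value_heads, head_dim = 8, 2, 4  # made-up sizes
kv_groups = num_attention_heads // num_key_value_heads        # 4 query heads per kv head
hidden = num_attention_heads * head_dim

# Fused wqkv rows: num_key_value_heads * (kv_groups + 2) * head_dim
wqkv = torch.randn(num_key_value_heads * (kv_groups + 2) * head_dim, hidden)
w = wqkv.view(-1, 2 + kv_groups, head_dim, hidden)
wq, wk, wv = torch.split(w, [kv_groups, 1, 1], dim=1)
print(wq.reshape(-1, hidden).shape)  # torch.Size([32, 32]) -> q weights
print(wk.reshape(-1, hidden).shape)  # torch.Size([8, 32])  -> k weights
print(wv.reshape(-1, hidden).shape)  # torch.Size([8, 32])  -> v weights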
@@ -707,6 +707,9 @@ class Qwen3MoeForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad()
def forward(
self,
......
@@ -6,6 +6,7 @@ from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
from sglang.srt.models.internvl import InternVLChatModel
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
@@ -14,12 +15,19 @@ from sglang.srt.multimodal.processors.base_processor import (
class InternVLImageProcessor(BaseMultimodalProcessor):
models = [InternVLChatModel, InternS1ForConditionalGeneration]
def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
image_size = (
getattr(hf_config, "force_image_size", None)
or hf_config.vision_config.image_size
)
patch_size = hf_config.vision_config.patch_size
if isinstance(image_size, list):
image_size = image_size[0]
if isinstance(patch_size, list):
patch_size = patch_size[0]
self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
self.IMG_START_TOKEN = "<img>"
@@ -27,8 +35,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
self.num_image_token = int(
(image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
)
if hasattr(self._processor, "tokenizer"):
tokenizer = self._processor.tokenizer
else:
tokenizer = self._processor
self.tokenizer = tokenizer
tokenizer = self._processor
self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
self.mm_tokens = MultimodalSpecialTokens(
@@ -195,7 +207,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
try:
# TODO: video input
raw_image = process_image_internvl(image)
pixel_value = [raw_image.to(torch.bfloat16)]
pixel_values += pixel_value
num_patches = raw_image.shape[0]
num_patches_list += [num_patches]
@@ -214,8 +226,9 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
)
input_text = input_text.replace("<image>", image_tokens, 1)
input_ids = self.tokenizer(input_text, return_tensors="pt")[
"input_ids"
].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.mm_tokens.image_token_id,
......
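Finally, a hedged sketch of how the processor expands a single <image> placeholder before tokenization. The construction of image_tokens is not shown in the hunk above, so the end-token string and the helper below are assumptions based on the InternVL convention.

IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"          # assumed; only the start/context tokens appear above
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"

def expand_image_placeholder(input_text, num_image_token, num_patches):
    # One <image> placeholder becomes <img> + num_image_token * num_patches
    # context tokens + </img>; the real processor does this once per image item.
    image_tokens = (
        IMG_START_TOKEN
        + IMG_CONTEXT_TOKEN * (num_image_token * num_patches)
        + IMG_END_TOKEN
    )
    return input_text.replace("<image>", image_tokens, 1)

expanded = expand_image_placeholder("<image>\nDescribe the picture.", 256, 1)
print(expanded.count(IMG_CONTEXT_TOKEN))  # 256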