Unverified commit b7094a5e authored by RunningLeon, committed by GitHub

model: support intern-s1 (#8350)


Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: zxy <zhou0493@e.ntu.edu.sg>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent da0c0260
......@@ -448,6 +448,19 @@ register_chat_template(
)
)
register_chat_template(
ChatTemplate(
name="interns1",
default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
stop_str=["<|im_end|>", "<|action_end|>"],
)
)
register_chat_template(
ChatTemplate(
name="granite-3-instruct",
......@@ -609,6 +622,14 @@ def match_internvl_chat(model_path: str):
return "internvl-2-5"
@register_chat_template_matching_function
def match_interns1_chat(model_path: str):
if re.search(r"intern-s1", model_path, re.IGNORECASE):
return "interns1"
if re.search(r"interns1", model_path, re.IGNORECASE):
return "interns1"
if __name__ == "__main__":
messages = [
{"role": "system", "content": None}, # None means default
......
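For context, a rough sketch of how the interns1 role prefixes and suffixes registered above compose into a prompt string. The user message and the abbreviated system prompt below are illustrative only; the real rendering is done by ChatTemplate.get_prompt, and the matcher above simply selects this template for model paths containing "intern-s1" or "interns1".

# Illustrative only: layout of an interns1 prompt built from the markers above.
# Generation stops on "<|im_end|>" or "<|action_end|>".
system_prompt = "You are an AI assistant whose name is Intern-S1 ..."  # abbreviated
prompt = (
    "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"
    "<|im_start|>user\nDescribe this image.<|im_end|>\n"
    "<|im_start|>assistant\n"
)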
......@@ -10,6 +10,7 @@ from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
Qwen2Config,
Qwen3Config,
)
from sglang.utils import logger
......@@ -314,6 +315,8 @@ class InternVLChatConfig(PretrainedConfig):
self.llm_config = InternLM2Config(**llm_config)
elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
self.llm_config = Qwen2Config(**llm_config)
elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
self.llm_config = Qwen3Config(**llm_config)
else:
raise ValueError(
"Unsupported architecture: {}".format(
......
......@@ -635,6 +635,7 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
"Phi4MMForCausalLM",
"VILAForConditionalGeneration",
]
......
......@@ -623,7 +623,7 @@ def generate_chat_conv(
real_content += content.text
elif content.type == "image_url":
# NOTE: works for llava and internvl-2-5
if conv.name == "internvl-2-5":
if conv.name in ["internvl-2-5", "interns1"]:
real_content = image_token + real_content
else:
real_content += image_token
......@@ -817,6 +817,19 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="interns1",
system_template="<|im_start|>system\n{system_message}",
system_message="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>\n",
stop_str=["<|im_end|>", "<|action_end|>"],
image_token="<image>",
)
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_conv_template(
Conversation(
......@@ -986,6 +999,8 @@ register_conv_template(
def match_internvl(model_path: str):
if re.search(r"internvl", model_path, re.IGNORECASE):
return "internvl-2-5"
if re.search(r"interns1", model_path, re.IGNORECASE):
return "interns1"
@register_conv_template_matching_function
......
......@@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
import functools
import math
from functools import lru_cache
from functools import lru_cache, partial
from typing import Any, Optional, Tuple, Union
import torch
......@@ -18,11 +18,16 @@ _is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel.flash_attn import flash_attn_varlen_func
from sglang.srt.distributed import parallel_state
from sglang.srt.distributed import (
parallel_state,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
)
from sglang.srt.distributed import utils as dist_utils
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
......@@ -349,25 +354,44 @@ class VisionAttention(nn.Module):
flatten_batch: bool = False,
prefix: str = "",
proj_bias: bool = True,
num_dummy_heads: int = 0,
qkv_bias: bool = True,
qk_normalization: bool = False,
layer_norm_eps: float = 1e-06,
**kwargs,
):
super().__init__()
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_size = world_size
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.dropout = dropout
self.head_size = embed_dim // num_heads
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads
)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, world_size
num_dummy_heads + num_heads, world_size
)
self.num_attention_kv_heads_per_partition = dist_utils.divide(
num_heads, world_size
num_dummy_heads + num_heads, world_size
)
self.q_size = self.num_attention_heads_per_partition * self.head_size
self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size
self.qk_normalization = qk_normalization
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size
if self.qk_normalization:
self.q_norm = RMSNorm(
self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
)
self.k_norm = RMSNorm(
self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
)
if global_server_args_dict["mm_attention_backend"] is None:
if qkv_backend is None:
qkv_backend = "sdpa"
......@@ -391,26 +415,46 @@ class VisionAttention(nn.Module):
self.qkv_proj = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.head_size,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
total_num_heads=num_dummy_heads + num_heads,
total_num_kv_heads=num_dummy_heads + num_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
else:
self.qkv_proj = ColumnParallelLinear(
input_size=embed_dim,
output_size=3 * projection_size,
output_size=3 * self.dummy_dim,
bias=qkv_bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.proj = RowParallelLinear(
input_size=embed_dim,
input_size=self.dummy_dim,
output_size=embed_dim,
bias=proj_bias,
quant_config=quant_config,
prefix=add_prefix("proj", prefix),
)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
"""apply qk norm for internvl vit attn"""
q = q.flatten(1, 2)
k = k.flatten(1, 2)
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
q = self.q_norm(q)
k = self.k_norm(k)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
q = q.unflatten(-1, (-1, self.head_size))
k = k.unflatten(-1, (-1, self.head_size))
return q, k
def forward(
self,
x: torch.Tensor,
......@@ -489,6 +533,10 @@ class VisionAttention(nn.Module):
assert k.dim() == 3, k.dim()
assert v.dim() == 3, v.dim()
# internvl
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
output = self.qkv_backend.forward(
q=q,
k=k,
......
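The qk-normalization path added above gathers the TP-sharded head dimension before applying RMSNorm and splits it again afterwards, because q_norm/k_norm operate over the full (dummy-padded) head dimension while each rank only holds a slice of the heads. A standalone sketch of the split/unflatten step, with made-up sizes and torch.chunk standing in for split_tensor_along_last_dim:

import torch

# Illustrative sizes: tp_size=2, 28 padded heads of size 64.
tp_size, tp_rank = 2, 0
tokens, head_size = 4, 64
dummy_dim = 28 * head_size                        # (num_heads + num_dummy_heads) * head_size

q_full = torch.randn(tokens, dummy_dim)           # after all-gather + q_norm
q_local = q_full.chunk(tp_size, dim=-1)[tp_rank]  # this rank's shard: [4, 896]
q_local = q_local.unflatten(-1, (-1, head_size))  # back to per-head layout: [4, 14, 64]
print(q_local.shape)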
......@@ -61,10 +61,15 @@ class RMSNorm(CustomOp):
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.hidden_size = hidden_size
self.variance_size_override = (
None if var_hidden_size == hidden_size else var_hidden_size
)
if _use_aiter:
self._forward_method = self.forward_aiter
......@@ -73,6 +78,8 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
......@@ -138,7 +145,25 @@ class RMSNorm(CustomOp):
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError(
"Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}"
)
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}"
)
x_var = x[..., : self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = (x * self.weight).to(orig_dtype)
if residual is None:
......
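A minimal, standalone sketch of the var_hidden_size override added to RMSNorm above: the variance is computed over only the leading var_hidden_size channels (the real, non-dummy part of a padded dimension), while the learned scale still covers the full hidden size. The sizes below are illustrative, not taken from any particular checkpoint.

import torch

hidden_size, var_hidden_size = 1792, 1664   # illustrative: padded dim vs. real dim
x = torch.randn(2, hidden_size)
weight = torch.ones(hidden_size)
eps = 1e-6

variance = x[..., :var_hidden_size].float().pow(2).mean(dim=-1, keepdim=True)
y = ((x.float() * torch.rsqrt(variance + eps)) * weight).to(x.dtype)
print(y.shape)  # torch.Size([2, 1792])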
from typing import Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.internvl import InternVisionModel
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
from sglang.utils import logger
class InternS1ForConditionalGeneration(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
use_flash_attn=True,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self._update_hf_config()
image_size = (
getattr(config, "force_image_size", None) or config.vision_config.image_size
)
patch_size = config.vision_config.patch_size
if isinstance(image_size, list):
image_size = image_size[0]
if isinstance(patch_size, list):
patch_size = patch_size[0]
self.patch_size = patch_size
self.select_layer = config.vision_feature_layer
self.num_image_token = int(
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
)
self.downsample_ratio = config.downsample_ratio
self.ps_version = getattr(config, "ps_version", "v1")
# self.template = getattr(config, 'template', 'internvl2_5')
config.vision_config.use_flash_attn = True if use_flash_attn else False
config.text_config._attn_implementation = (
"flash_attention_2" if use_flash_attn else "eager"
)
logger.info(f"num_image_token: {self.num_image_token}")
logger.info(f"ps_version: {self.ps_version}")
self.vision_model = InternVisionModel(config.vision_config)
if config.text_config.architectures[0] == "Qwen2ForCausalLM":
self.language_model = Qwen2ForCausalLM(
config=config.text_config, quant_config=quant_config
)
elif config.text_config.architectures[0] == "Qwen3MoeForCausalLM":
self.language_model = Qwen3MoeForCausalLM(
config=config.text_config, quant_config=quant_config
)
else:
raise NotImplementedError(
f"{config.text_config.architectures[0]} is not implemented."
)
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(
vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
def _update_hf_config(self):
"""update hf config to support tp"""
world_size = parallel_state.get_tensor_model_parallel_world_size()
num_heads = self.config.vision_config.num_attention_heads
head_dim = self.config.vision_config.hidden_size // num_heads
num_dummy_heads = 0
if num_heads % world_size != 0:
num_dummy_heads = (
(num_heads + world_size) // world_size
) * world_size - num_heads
setattr(self.config.vision_config, "head_dim", head_dim)
setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
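A worked example of the dummy-head arithmetic in _update_hf_config: when the vision tower's head count does not divide the TP world size, it is rounded up to the next multiple and the difference is padded with dummy heads. The head counts below are assumptions for illustration only.

def dummy_heads(num_heads: int, world_size: int) -> int:
    # Mirrors _update_hf_config: round num_heads up to a multiple of world_size.
    if num_heads % world_size == 0:
        return 0
    return ((num_heads + world_size) // world_size) * world_size - num_heads

assert dummy_heads(25, 2) == 1   # 25 -> 26 heads, 13 per rank
assert dummy_heads(25, 4) == 3   # 25 -> 28 heads, 7 per rank
assert dummy_heads(32, 8) == 0   # already divisible, no padding needed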
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(
n,
int(h * scale_factor),
int(w * scale_factor),
int(c / (scale_factor * scale_factor)),
)
if self.ps_version == "v1":
logger.warn(
"In ps_version 'v1', the height and width have not been swapped back, "
"which results in a transposed image."
)
else:
x = x.permute(0, 2, 1, 3).contiguous()
return x
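A quick shape check for pixel_shuffle with scale_factor=0.5 (the grid and hidden sizes below are illustrative): each spatial side shrinks by 2 while the channel dimension grows by 4, which is why mlp1 takes vit_hidden_size * int(1 / downsample_ratio) ** 2 input features.

import torch

n, w, h, c = 1, 32, 32, 1024                         # illustrative ViT token grid
x = torch.randn(n, w, h, c)
x = x.view(n, w, int(h * 0.5), int(c / 0.5))         # [1, 32, 16, 2048]
x = x.permute(0, 2, 1, 3).contiguous()               # [1, 16, 32, 2048]
x = x.view(n, int(h * 0.5), int(w * 0.5), int(c / 0.25))
print(x.shape)                                       # torch.Size([1, 16, 16, 4096])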
def extract_feature(self, pixel_values):
if self.select_layer == -1:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
).last_hidden_state
else:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
).hidden_states[self.select_layer]
vit_embeds = vit_embeds[:, 1:, :]
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
def get_image_feature(self, items: List[MultimodalDataItem]):
"""
Projects the last hidden state from the vision model into language model space.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
pixel_values = torch.cat([item.feature for item in items])
image_features = self.extract_feature(pixel_values)
return image_features
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hs = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
return hs
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = mm_inputs.im_start_id
im_end_id: int = mm_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return helper.pad_input_tokens(input_ids, mm_inputs)
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads"""
num_dummy_heads = self.config.vision_config.num_dummy_heads
if num_dummy_heads == 0:
return loaded_weight
head_dim = self.config.vision_config.head_dim
if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
if name.endswith(".weight"):
dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
elif name.endswith(".bias"):
dummy_shape = [num_dummy_heads, head_dim]
else:
raise RuntimeError(f"Unsupported weight with name={name}")
padded_weight = loaded_weight.new_zeros(dummy_shape)
loaded_weight = torch.cat(
[loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
).flatten(0, 1)
if "attn.proj.weight" in name:
padded_weight = loaded_weight.new_zeros(
loaded_weight.shape[0], head_dim * num_dummy_heads
)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
return loaded_weight
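The zero-padding performed by _pad_vit_attn_dummy_heads can be checked in isolation: a q/k/v projection weight of shape [num_heads * head_dim, in_features] gains num_dummy_heads extra zero head-rows so that the padded tensor shards evenly across TP ranks. The sizes below are illustrative.

import torch

num_heads, num_dummy_heads, head_dim, in_features = 25, 3, 128, 3200
w = torch.randn(num_heads * head_dim, in_features)

pad = w.new_zeros(num_dummy_heads, head_dim, in_features)
w_padded = torch.cat([w.unflatten(0, (-1, head_dim)), pad], dim=0).flatten(0, 1)
print(w_padded.shape)  # torch.Size([3584, 3200]) == (25 + 3) * 128 rows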
def _mapping_interns1_name(self, name):
names_map = {
"lm_head.weight": "language_model.lm_head.weight",
"model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
"model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
"model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
"model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
"model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
"model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
"model.vision_tower.embeddings.cls_token": "vision_model.embeddings.class_embedding",
"model.vision_tower.embeddings.patch_embeddings.projection.bias": "vision_model.embeddings.patch_embedding.bias",
"model.vision_tower.embeddings.patch_embeddings.projection.weight": "vision_model.embeddings.patch_embedding.weight",
"model.vision_tower.embeddings.position_embeddings": "vision_model.embeddings.position_embedding",
}
if name in names_map:
name = names_map[name]
elif name.startswith("model.language_model."):
name = "language_model.model." + name[len("model.language_model.") :]
elif name.startswith("model.vision_tower."):
name = "vision_model." + name[len("model.vision_tower.") :]
if name.startswith("vision_model.encoder.layer"):
name = name.replace(r".layer.", r".layers.")
name = name.replace(r".attention.", r".attn.attn.")
name = name.replace(r".projection_layer.", r".proj.")
name = name.replace(r".lambda_1", r".ls1")
name = name.replace(r".lambda_2", r".ls2")
name = name.replace(r".layernorm_before.", r".norm1.")
name = name.replace(r".layernorm_after.", r".norm2.")
return name
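Two representative renames produced by the rules in _mapping_interns1_name (the checkpoint keys are illustrative HF-style names): a key of the form model.language_model.layers.0.mlp.gate_proj.weight becomes language_model.model.layers.0.mlp.gate_proj.weight, and model.vision_tower.encoder.layer.0.attention.q_proj.weight becomes vision_model.encoder.layers.0.attn.attn.q_proj.weight via the prefix swap plus the .layer. and .attention. substitutions.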
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = []
if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
name = self._mapping_interns1_name(name)
if "vision_model" in name:
loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
f"Some weights are not initialized from checkpoints: {unloaded_params}"
)
return loaded_params
EntryClass = [InternS1ForConditionalGeneration]
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
......@@ -23,7 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
......@@ -39,6 +28,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_janus_pro import DropPath
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
from sglang.utils import logger
......@@ -53,7 +43,6 @@ class InternAttention(nn.Module):
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.attn = VisionAttention(
......@@ -64,18 +53,16 @@ class InternAttention(nn.Module):
use_qkv_parallel=True,
quant_config=quant_config,
dropout=getattr(config, "dropout", 0.0),
proj_bias=getattr(config, "qkv_bias", True),
qkv_bias=getattr(config, "qkv_bias", False)
or getattr(config, "attention_bias", False),
num_dummy_heads=getattr(config, "num_dummy_heads", 0),
qk_normalization=getattr(config, "qk_normalization", False)
or getattr(config, "use_qk_norm", False),
flatten_batch=False,
)
self.proj_drop = nn.Dropout(config.dropout)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
......@@ -91,8 +78,16 @@ class InternVisionEmbeddings(nn.Module):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.image_size = (
config.image_size
if isinstance(config.image_size, int)
else config.image_size[0]
)
self.patch_size = (
config.patch_size
if isinstance(config.patch_size, int)
else config.patch_size[0]
)
self.class_embedding = nn.Parameter(
torch.randn(1, 1, self.embed_dim),
......@@ -199,7 +194,7 @@ class InternVisionEncoderLayer(nn.Module):
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
self.attn = InternAttention(config)
self.attn = InternAttention(config=config, quant_config=quant_config)
self.mlp = InternMLP(config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
......@@ -417,7 +412,7 @@ class InternVLChatModel(nn.Module):
super().__init__()
self.config = config
self.quant_config = quant_config
self._update_vision_config()
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
......@@ -446,6 +441,10 @@ class InternVLChatModel(nn.Module):
self.language_model = InternLM2ForCausalLM(
config=config.llm_config, quant_config=quant_config
)
elif config.llm_config.architectures[0] == "Qwen3MoeForCausalLM":
self.language_model = Qwen3MoeForCausalLM(
config=config.llm_config, quant_config=quant_config
)
else:
raise NotImplementedError(
f"{config.llm_config.architectures[0]} is not implemented."
......@@ -463,6 +462,21 @@ class InternVLChatModel(nn.Module):
nn.Linear(llm_hidden_size, llm_hidden_size),
)
def _update_vision_config(self):
"""update vision config to support tp"""
world_size = parallel_state.get_tensor_model_parallel_world_size()
num_heads = self.config.vision_config.num_attention_heads
head_dim = self.config.vision_config.hidden_size // num_heads
num_dummy_heads = 0
if num_heads % world_size != 0:
num_dummy_heads = (
(num_heads + world_size) // world_size
) * world_size - num_heads
setattr(self.config.vision_config, "head_dim", head_dim)
setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
......@@ -545,7 +559,38 @@ class InternVLChatModel(nn.Module):
return helper.pad_input_tokens(input_ids, mm_inputs)
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads"""
num_dummy_heads = self.config.vision_config.num_dummy_heads
if num_dummy_heads == 0:
return loaded_weight
head_dim = self.config.vision_config.head_dim
if "attn.qkv_proj" in name:
wq, wk, wv = loaded_weight.chunk(3, dim=0)
if name.endswith(".weight"):
dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
elif name.endswith(".bias"):
dummy_shape = [num_dummy_heads, head_dim]
else:
raise RuntimeError(f"Unsupported weight with name={name}")
pad_func = lambda x: torch.cat(
[x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
).flatten(0, 1)
wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
loaded_weight = torch.cat([wq, wk, wv], dim=0)
if "attn.proj.weight" in name:
padded_weight = loaded_weight.new_zeros(
loaded_weight.shape[0], head_dim * num_dummy_heads
)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
return loaded_weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = []
if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......@@ -561,15 +606,41 @@ class InternVLChatModel(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
elif "Qwen3MoeForCausalLM" in self.config.llm_config.architectures:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
......@@ -584,30 +655,55 @@ class InternVLChatModel(nn.Module):
name = name.replace(r"attn.", r"attn.attn.")
name = name.replace(r"qkv.", r"qkv_proj.")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if "wqkv" in name:
config = self.config
kv_groups = config.num_attention_heads // config.num_key_value_heads
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(
-1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
)
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, wq, "q")
weight_loader(param, wk, "k")
weight_loader(param, wv, "v")
else:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
weight_loader(param, loaded_weight)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if "wqkv" in name:
config = self.config
kv_groups = (
config.num_attention_heads // config.num_key_value_heads
)
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(
-1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
)
wq, wk, wv = torch.split(
loaded_weight, [kv_groups, 1, 1], dim=1
)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
weight_loader = param.weight_loader
weight_loader(param, wq, "q")
weight_loader(param, wk, "k")
weight_loader(param, wv, "v")
else:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
if "vision_model" in name:
loaded_weight = self._pad_vit_attn_dummy_heads(
name, loaded_weight
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
......
......@@ -707,6 +707,9 @@ class Qwen3MoeForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad()
def forward(
self,
......
......@@ -6,6 +6,7 @@ from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
from sglang.srt.models.internvl import InternVLChatModel
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
......@@ -14,12 +15,19 @@ from sglang.srt.multimodal.processors.base_processor import (
class InternVLImageProcessor(BaseMultimodalProcessor):
models = [InternVLChatModel]
models = [InternVLChatModel, InternS1ForConditionalGeneration]
def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
image_size = hf_config.force_image_size or hf_config.vision_config.image_size
image_size = (
getattr(hf_config, "force_image_size", None)
or hf_config.vision_config.image_size
)
patch_size = hf_config.vision_config.patch_size
if isinstance(image_size, list):
image_size = image_size[0]
if isinstance(patch_size, list):
patch_size = patch_size[0]
self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
self.IMG_START_TOKEN = "<img>"
......@@ -27,8 +35,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
self.num_image_token = int(
(image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
)
if hasattr(self._processor, "tokenizer"):
tokenizer = self._processor.tokenizer
else:
tokenizer = self._processor
self.tokenizer = tokenizer
tokenizer = self._processor
self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
self.mm_tokens = MultimodalSpecialTokens(
......@@ -195,7 +207,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
try:
# TODO: video input
raw_image = process_image_internvl(image)
pixel_value = [raw_image.to(torch.bfloat16).cuda()]
pixel_value = [raw_image.to(torch.bfloat16)]
pixel_values += pixel_value
num_patches = raw_image.shape[0]
num_patches_list += [num_patches]
......@@ -214,8 +226,9 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
)
input_text = input_text.replace("<image>", image_tokens, 1)
tokenizer = self._processor
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
input_ids = self.tokenizer(input_text, return_tensors="pt")[
"input_ids"
].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.mm_tokens.image_token_id,
......