Unverified Commit 8fefdd32 authored by liwenju0, committed by GitHub

[Feature] add support kimi vl model (#5383)


Co-authored-by: wenju.li <wenju.li@deepctr.cn>
parent 403b855a
......@@ -28,4 +28,5 @@ python3 -m sglang.launch_server \
| **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | `vicuna_v1.1` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. |
| **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. |
| **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. |
| **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | `gemma-it` | Gemma 3’s larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. |
| **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | `kimi-vl` | Moonshot AI's vision-language model pairing a MoonViT vision encoder with a DeepSeek-V2-style MoE language model (~3B activated parameters) for image understanding and multimodal chat. |
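For a quick smoke test, Kimi-VL can be served and queried through the OpenAI-compatible API like the other VLMs above. The snippet below is a minimal sketch based on the launch flags used in the test suite at the bottom of this diff (`--trust-remote-code`, `--chat-template kimi-vl`); the port, API key, and image URL are placeholders.

```python
# Hypothetical usage sketch; launch the server first, e.g.:
#   python3 -m sglang.launch_server --model-path moonshotai/Kimi-VL-A3B-Instruct \
#       --trust-remote-code --chat-template kimi-vl --port 30000
import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
response = client.chat.completions.create(
    model="moonshotai/Kimi-VL-A3B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
    max_tokens=64,
)
print(response.choices[0].message.content)
```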
......@@ -42,6 +42,7 @@ runtime_common = [
"uvicorn",
"uvloop",
"xgrammar==0.1.17",
"blobfile==3.0.0"
]
srt = [
......
......@@ -3,6 +3,8 @@ from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
__all__ = [
"ExaoneConfig",
......@@ -10,4 +12,6 @@ __all__ = [
"DbrxConfig",
"DeepseekVL2Config",
"MultiModalityConfig",
"KimiVLConfig",
"MoonViTConfig",
]
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from typing import Optional, Union
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
class KimiVLConfig(PretrainedConfig):
model_type = "kimi_vl"
def __init__(
self,
vision_config: Optional[Union[dict, MoonViTConfig]] = None,
text_config: Optional[Union[dict, DeepseekV2Config]] = None,
ignore_index: int = -100,
media_placeholder_token_id: int = 163605,
pad_token_id: int = 0,
**kwargs
):
if vision_config is None:
vision_config = MoonViTConfig()
elif isinstance(vision_config, dict):
vision_config = MoonViTConfig(**vision_config)
self.vision_config = vision_config
if text_config is None:
text_config = DeepseekV2Config()
elif isinstance(text_config, dict):
text_config = DeepseekV2Config(**text_config)
self.text_config = text_config
self.ignore_index = ignore_index
self.media_placeholder_token_id = media_placeholder_token_id
super().__init__(pad_token_id=pad_token_id, **kwargs)
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from transformers.configuration_utils import PretrainedConfig
class MoonViTConfig(PretrainedConfig):
model_type = "moonvit"
def __init__(
self,
patch_size: int = 14,
init_pos_emb_height: int = 64,
init_pos_emb_width: int = 64,
num_attention_heads: int = 16,
num_hidden_layers: int = 27,
hidden_size: int = 1152,
intermediate_size: int = 4304,
merge_kernel_size: tuple[int, int] = (2, 2),
**kwargs,
):
super().__init__(**kwargs)
self.patch_size = patch_size
# Positional embedding config
self.init_pos_emb_height = init_pos_emb_height
self.init_pos_emb_width = init_pos_emb_width
# Transformer config
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
# Patch merger config
self.merge_kernel_size = merge_kernel_size
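As a usage sketch, the two configs compose as follows; the nested dicts mirror what would be parsed from a checkpoint's `config.json`, and the field values shown are the defaults above rather than the released checkpoint's.

```python
# Sketch: KimiVLConfig accepts dicts or config objects for its sub-configs and
# falls back to defaults when they are omitted.
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig

cfg = KimiVLConfig(
    vision_config={"patch_size": 14, "hidden_size": 1152},  # becomes a MoonViTConfig
    text_config=None,  # falls back to DeepseekV2Config defaults
)
assert isinstance(cfg.vision_config, MoonViTConfig)
print(cfg.media_placeholder_token_id)  # 163605, the default above
```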
......@@ -176,6 +176,13 @@ class ModelConfig:
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
self.v_head_dim = self.hf_text_config.v_head_dim
self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
else:
self.attention_arch = AttentionArch.MHA
......@@ -530,6 +537,7 @@ multimodal_model_archs = [
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"CLIPModel",
"KimiVLForConditionalGeneration",
]
......
......@@ -806,6 +806,24 @@ register_conv_template(
)
)
# Reference: https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/chat_template.jinja
register_conv_template(
Conversation(
name="kimi-vl",
system_message="You are a helpful assistant",
system_template="<|im_system|>system<|im_middle|>{system_message}",
roles=(
"<|im_user|>user<|im_middle|>",
"<|im_assistant|>assistant<|im_middle|>",
),
messages=[],
sep="<|im_end|>",
sep_style=SeparatorStyle.NO_COLON_SINGLE,
stop_str="<|im_end|>",
image_token="<|media_start|>image<|media_content|><|media_pad|><|media_end|>",
)
)
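For orientation, a single-turn request rendered with this template should produce a prompt roughly like the string below (a sketch assembled by hand; the exact placement of the `<|im_end|>` separator is handled by `Conversation.get_prompt`, and the single `<|media_pad|>` placeholder is later expanded to the real image token count by the multimodal processor).

```python
# Hand-assembled sketch of the rendered kimi-vl prompt for one user turn with an image.
prompt = (
    "<|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|>"
    "<|im_user|>user<|im_middle|>"
    "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"
    "What is in this picture?<|im_end|>"
    "<|im_assistant|>assistant<|im_middle|>"
)
# Generation stops when the model emits the stop string <|im_end|>.
```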
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
......@@ -888,3 +906,10 @@ def match_openbmb_minicpm(model_path: str):
return "minicpmv"
elif "minicpm-o" in model_path:
return "minicpmo"
@register_conv_template_matching_function
def match_moonshot_kimivl(model_path: str):
model_path = model_path.lower()
if "kimi" in model_path and "vl" in model_path:
return "kimi-vl"
......@@ -35,6 +35,7 @@ from sglang.srt.configs import (
DbrxConfig,
DeepseekVL2Config,
ExaoneConfig,
KimiVLConfig,
MultiModalityConfig,
)
from sglang.srt.connector import create_remote_connector
......@@ -46,6 +47,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ExaoneConfig.model_type: ExaoneConfig,
DeepseekVL2Config.model_type: DeepseekVL2Config,
MultiModalityConfig.model_type: MultiModalityConfig,
KimiVLConfig.model_type: KimiVLConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
......
import asyncio
import math
from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
# Compatible with KimiVLForConditionalGeneration
class KimiVLImageProcessor(SGLangBaseProcessor):
models = [KimiVLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|media_pad|>"
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
self.im_start = "<|media_start|>"
self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
self.im_end = "<|media_end|>"
self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
self.im_content = "<|media_content|>"
self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
max_req_input_len=max_req_input_len,
)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"mm_items": [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=ret["image_grid_hws"],
modality=Modality.IMAGE,
)
],
"im_token_id": self.im_token_id,
"im_start_id": self.im_start_id,
"im_end_id": self.im_end_id,
"im_content_id": self.im_content_id,
}
......@@ -752,7 +752,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope_out.transpose(0, 1)
k_nope = latent_cache[..., : self.kv_lora_rank]
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
......@@ -1391,6 +1391,9 @@ class DeepseekV2Model(nn.Module):
self.dp_size = get_attention_dp_size()
def get_input_embeddings(self) -> torch.Tensor:
return self.embed_tokens
def forward(
self,
input_ids: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import logging
import math
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers.activations import GELUActivation
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
from sglang.srt.models.kimi_vl_moonvit import MoonVitPretrainedModel
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
# For dummy input only
@dataclass
class MaxImageTokenMeta:
width: int = 1024
height: int = 1024
class KimiVLMultiModalProjector(nn.Module):
def __init__(self, config: KimiVLConfig):
super().__init__()
self.hidden_size = (
config.vision_config.hidden_size
* config.vision_config.merge_kernel_size[0]
* config.vision_config.merge_kernel_size[1]
)
self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-5)
self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
self.act = GELUActivation()
self.act = QuickGELU()
self.linear_2 = nn.Linear(
self.hidden_size, config.text_config.hidden_size, bias=True
)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
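As a shape sanity check (a standalone sketch using the MoonViT defaults above; the text hidden size is a placeholder), the projector consumes the merged patch groups produced by `patch_merger` in `kimi_vl_moonvit.py`, flattens each 2x2 merge window into the channel dimension, and projects to the language model's hidden size:

```python
# Standalone shape sketch; values other than the MoonViT defaults are placeholders.
import torch

vit_hidden, kh, kw = 1152, 2, 2       # MoonViTConfig defaults
text_hidden = 2048                    # placeholder for config.text_config.hidden_size

merged = torch.randn(64, kh * kw, vit_hidden)           # 64 merged 2x2 patch groups
flat = torch.nn.LayerNorm(vit_hidden)(merged).view(-1, kh * kw * vit_hidden)  # (64, 4608)
out = torch.nn.Linear(kh * kw * vit_hidden, text_hidden)(flat)                # (64, 2048)
print(out.shape)  # one text-space embedding per merged patch group
```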
class KimiVLForConditionalGeneration(nn.Module):
def __init__(
self,
config: KimiVLConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
**kwargs, # fix init_tts argument error
) -> None:
super().__init__()
self.config = config
assert isinstance(config.vision_config, MoonViTConfig)
self.vision_tower = MoonVitPretrainedModel(config.vision_config)
self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
self.quant_config = quant_config
text_config = copy.deepcopy(config.text_config)
text_config.architectures = ["DeepseekV2ForCausalLM"]
self.language_model = DeepseekV2ForCausalLM(
config=text_config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = (
torch.cat([item.pixel_values for item in items], dim=0)
.type(self.vision_tower.dtype)
.to(self.vision_tower.device)
)
image_grid_thws = torch.concat(
[item.image_grid_thws for item in items], dim=0
).to(self.vision_tower.device)
image_features = self.vision_tower(pixel_values, image_grid_thws)
assert isinstance(image_features, list)
# lengths = [x.shape[0] for x in image_features]
res = self.multi_modal_projector(torch.cat(image_features)) # .split(lengths)
return res
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs
pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
return pattern.pad_input_tokens(input_ids, mm_inputs)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
positions=positions,
)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
config = self.config.text_config
_KEYS_TO_MODIFY_MAPPING = {
# "language_model.lm_head": "lm_head",
# "language_model.model": "language_model",
}
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
if not config.use_mla:
stacked_params_mapping += [
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
if getattr(config, "n_routed_experts", None):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=config.n_routed_experts,
)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id, **kwargs)
break
else:
for idx, (
param_name,
weight_name,
expert_id,
shard_id,
) in enumerate(expert_params_mapping):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
expert_id=expert_id,
shard_id=shard_id,
**kwargs,
)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# if is_pp_missing_parameter(name, self):
# continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, **kwargs)
self.language_model.post_load_weights()
def get_spec_layer_idx_from_weight_name(
config: DeepseekV2Config, weight_name: str
) -> Optional[int]:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."):
return layer_idx + i
return None
EntryClass = [KimiVLForConditionalGeneration]
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# This file is meant to be used in kimi_vl.py only
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from copy import deepcopy
from functools import cached_property
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN, PytorchGELUTanh
from transformers.modeling_utils import PreTrainedModel
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
from sglang.srt.configs import MoonViTConfig
def multihead_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_cu_seqlens: Optional[torch.Tensor] = None,
k_cu_seqlens: Optional[torch.Tensor] = None,
):
"""Multi-head attention using flash attention 2.
This function is used to handle the case where the query, key, and value are packed.
Args:
q, k, v: tensor of shape (tot_seqlens, num_heads, head_dim).
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
The first element should be 0 and the last element should be q.shape[0].
k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
The first element should be 0 and the last element should be k.shape[0].
Returns:
output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
where dim = num_heads * head_dim
"""
if flash_attn_varlen_func is None:
raise ImportError(
"flash_attn is not installed, this function needs flash_attn_varlen_func from flash_attn"
)
# Unified format legal check
assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
assert (
k_cu_seqlens[-1] == k.shape[0] == v.shape[0]
), "k_cu_seqlens must sum to k.shape[0]"
assert q.dtype in [
torch.bfloat16,
torch.float16,
], f"unsupported dtype {q.dtype} for multihead attn"
max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
attn_out = flash_attn_varlen_func(
q,
k,
v,
q_cu_seqlens,
k_cu_seqlens,
max_seqlen_q,
max_seqlen_k,
causal=False,
)
attn_out = attn_out.flatten(start_dim=-2)
return attn_out
def sdpa_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_cu_seqlens: Optional[torch.Tensor] = None,
k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Multi-head attention using torch scaled dot product attention.
This function is used to handle the case where the query, key, and value are packed.
Args:
q, k, v: tensor of shape (tot_seqlens, num_heads, head_dim).
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
The first element should be 0 and the last element should be q.shape[0].
k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
The first element should be 0 and the last element should be k.shape[0].
Returns:
output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
where dim = num_heads * head_dim
"""
# Unified format legal check
assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
seq_length = q.shape[0]
attention_mask = torch.zeros(
[1, seq_length, seq_length], device=q.device, dtype=torch.bool
)
for i in range(1, len(q_cu_seqlens)):
attention_mask[
...,
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
return attn_output
VL_VISION_ATTENTION_FUNCTIONS = {
"flash_attention_2": multihead_attention,
"sdpa": sdpa_attention,
}
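Both backends consume packed sequences described by cumulative lengths rather than a batch dimension: e.g. two images contributing 4 and 6 patch tokens are packed into one (10, num_heads, head_dim) tensor with `cu_seqlens = [0, 4, 10]`, and the SDPA path turns that into a block-diagonal mask so tokens never attend across images. A minimal standalone sketch of the layout:

```python
# Standalone sketch of the packed layout and the block-diagonal mask the SDPA
# backend builds; head counts/dims follow the MoonViT defaults (1152 hidden / 16 heads).
import torch

num_heads, head_dim = 16, 72
seqlens = [4, 6]                                    # two images: 4 and 6 patch tokens
cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32)

q = torch.randn(sum(seqlens), num_heads, head_dim)  # packed, no batch dimension

mask = torch.zeros(1, sum(seqlens), sum(seqlens), dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    s, e = cu_seqlens[i - 1], cu_seqlens[i]
    mask[..., s:e, s:e] = True                      # attention stays within each image
print(mask[0].int())
```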
def _apply_rope_input_validation(x, freqs_cis):
assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
def apply_rope(
xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args: (The leading dimensions of all inputs should be the same)
xq: query, tensor of shape (..., num_heads, head_dim)
xk: key, tensor of shape (..., num_heads, head_dim)
freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
Returns:
xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
"""
_apply_rope_input_validation(xq, freqs_cis)
_apply_rope_input_validation(xk, freqs_cis)
freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2
# ..., num_heads, head_dim/2
xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
return xq_out.type_as(xq), xk_out.type_as(xk)
class Learnable2DInterpPosEmb(nn.Module):
def __init__(
self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
) -> None:
super().__init__()
self.height = height
self.width = width
self.interpolation_mode = interpolation_mode
self.weight = nn.Parameter(torch.empty(height, width, dim))
self.reset_parameters()
def reset_parameters(self):
nn.init.normal_(self.weight)
def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
pos_embs = []
for shape in grid_hws.tolist():
if shape == self.weight.shape[:-1]:
pos_embs.append(self.weight.flatten(end_dim=1))
else:
pos_embs.append(
F.interpolate(
self.weight.permute((2, 0, 1)).unsqueeze(0),
size=shape,
mode=self.interpolation_mode,
)
.squeeze(0)
.permute((1, 2, 0))
.flatten(end_dim=1)
)
out = x + torch.cat(pos_embs)
return out
class MoonVisionPatchEmbed(nn.Module):
def __init__(
self,
out_dim: int,
in_dim: int = 3,
patch_size: Union[int, Tuple[int, int]] = (14, 14),
pos_emb_height: int = 14,
pos_emb_width: int = 14,
):
super().__init__()
assert isinstance(
patch_size, (int, Sequence)
), f"Invalid patch_size type: {type(patch_size)}"
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
assert (
len(patch_size) == 2
), f"Expected patch_size to be a tuple of 2, got {patch_size}"
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_dim, out_dim, kernel_size=patch_size, stride=patch_size
)
self.pos_emb = Learnable2DInterpPosEmb(
height=pos_emb_height, width=pos_emb_width, dim=out_dim
)
def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
"""
Args:
x (L, in_dim, patch_h, patch_w): packed image patches
grid_hw (N, 2): grid height and width
Returns:
(L, Cout) tensor
"""
x = self.proj(x).view(x.size(0), -1)
# apply positional embedding
x = self.pos_emb(x, grid_hw)
return x
class Rope2DPosEmb(nn.Module):
"""2D rotary position embedding with multi-resolution support.
This class is intended to be used in the following way:
1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
The rope is shared across all attention layers and all heads.
Refs:
- RoFormer: https://arxiv.org/abs/2104.09864
- VisionLLaMA: https://arxiv.org/abs/2403.00522
- https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py
Args:
dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
max_height (int): the maximum height of the 2D grid
max_width (int): the maximum width of the 2D grid
theta_base (float): the base of the theta
device (str): the device to store the precomputed cis
"""
def __init__(
self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
):
super().__init__()
self.dim = dim
assert self.dim % 4 == 0, "dim must be divisible by 4"
self.max_height = max_height
self.max_width = max_width
self.theta_base = theta_base
self.device = device
def extra_repr(self):
return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
@cached_property
def precomputed_freqs_cis(self) -> torch.Tensor:
"""Calculate the cis(freqs) for each position in the 2D grid.
Return: complex tensor of shape (max_height, max_width, dim//2) and value:
width axis: ret[h, w, 2*i] = cis(w * theta_base**(-4*i/dim))
height axis: ret[h, w, 2*i+1] = cis(h * theta_base**(-4*i/dim)) with (i in [0, dim//4))
note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
"""
N = self.max_height * self.max_width
flat_pos = torch.arange(0, N).float().to(self.device)
x_pos = flat_pos % self.max_width
y_pos = flat_pos // self.max_width
dim_range = (
torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device)
) # C/4
freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
x_freqs = torch.outer(x_pos, freqs).float() # N, C/4
y_freqs = torch.outer(y_pos, freqs).float() # N, C/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4
# N, C/4, 2
freqs_cis = torch.cat(
[x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
)
# max_height, max_width, C/2
freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
return freqs_cis
def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
"""
Args:
grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
Returns:
freqs_cis: tensor of shape (sum(t * height * width), dim//2)
"""
shapes = grid_hws.tolist()
assert all(
1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
), (
shapes,
self.max_height,
self.max_width,
)
freqs_cis = torch.cat(
[
self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
for h, w in shapes
],
dim=0,
)
return freqs_cis
def get_freqs_cis_by_idx(
self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor
) -> torch.Tensor:
"""
Args:
pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask will be ones.
Return:
freqs_cis: tensor of shape (..., dim//2)
"""
assert (
pos_idx.shape[:-1] == pos_idx_mask.shape
and pos_idx.shape[-1] == 2
and pos_idx.ndim == pos_idx_mask.ndim + 1
), (pos_idx.shape, pos_idx_mask.shape)
assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype
shp = pos_idx_mask.shape + (self.dim // 2,) # ..., head_dim/2
freqs_cis = torch.ones(
shp, dtype=torch.complex64, device=self.device
) # ..., head_dim/2
freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]
]
return freqs_cis
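A small worked example of the 2D RoPE helpers above (run on CPU to override the default `device="cuda"`; `dim=8` is chosen tiny just to keep the shapes readable):

```python
# Tiny sketch exercising Rope2DPosEmb and apply_rope defined above.
import torch

rope = Rope2DPosEmb(dim=8, max_height=4, max_width=4, device="cpu")
print(rope.precomputed_freqs_cis.shape)            # (4, 4, 4): (H, W, dim//2), complex64

grid_hws = torch.tensor([[2, 3], [1, 4]])          # two images: 2x3 and 1x4 patch grids
freqs_cis = rope.get_freqs_cis_by_seqlens(grid_hws)
print(freqs_cis.shape)                             # (10, 4) == (2*3 + 1*4, dim//2)

xq = torch.randn(10, 2, 8)                         # (tokens, num_heads, head_dim)
xk = torch.randn(10, 2, 8)
xq_out, xk_out = apply_rope(xq, xk, freqs_cis)
print(xq_out.shape, xk_out.shape)                  # shapes unchanged: (10, 2, 8)
```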
class MLP2(nn.Module):
"""
Args:
dims: [in_dim, hidden_dim, out_dim]
bias: whether to use bias in linear layer.
"""
def __init__(self, dims: list[int], activation, bias=True):
super().__init__()
assert len(dims) == 3
self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
self.activation = activation
for m in [self.fc0, self.fc1]:
nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc0(x)
x = self.activation(x)
return self.fc1(x)
class MoonVitEncoderLayer(nn.Module):
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
*,
attn_implementation: str = "flash_attention_2", # use fa2 in sglang by default
activation=F.gelu,
attn_bias: bool = False,
):
super().__init__()
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
self.attn_implementation = attn_implementation
self.norm0 = nn.LayerNorm(hidden_dim)
self.norm1 = nn.LayerNorm(hidden_dim)
self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
def attention_qkvpacked(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rope_freqs_cis: Optional[torch.Tensor] = None,
):
"""
Args:
x (torch.Tensor): (batch_size, seqlen, hidden_dim)
cu_seqlens (torch.Tensor):
"""
xqkv = self.wqkv(x)
qkv_shape = xqkv.size()[:-1] + (
3,
self.num_heads,
self.hidden_size_per_attention_head,
)
# xqkv: (batch_size, seqlen, 3, nheads, headdim)
xqkv = xqkv.view(*qkv_shape)
xq, xk, xv = torch.unbind(xqkv, dim=-3)
xq, xk = apply_rope(xq, xk, rope_freqs_cis)
attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
attn_out = attn_func(
xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
)
attn_out = self.wo(attn_out)
return attn_out
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rope_freqs_cis: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
"""
Args:
hidden_states: non-packed (B, N, D) or packed (L, D). If non-packed, cu_seqlens should be None; if packed, cu_seqlens should be set.
Returns:
output: same shape as the input, i.e. (B, N, D) for non-packed input and (L, D) for packed input
"""
residual = hidden_states
hidden_states = self.norm0(hidden_states)
attn_out = self.attention_qkvpacked(
hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
)
hidden_states = residual + attn_out
residual = hidden_states
hidden_states = self.mlp(self.norm1(hidden_states))
hidden_states = residual + hidden_states
return hidden_states
class MoonVitEncoder(nn.Module):
def __init__(
self,
hidden_dim: int,
num_layers: int,
block_cfg: dict,
) -> None:
super().__init__()
self.rope_2d = Rope2DPosEmb(
block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512
)
self.blocks = nn.ModuleList(
[MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)]
)
self.final_layernorm = nn.LayerNorm(hidden_dim)
def forward(
self, hidden_states: torch.Tensor, grid_hw: torch.Tensor
) -> torch.Tensor:
rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw)
lengths = torch.cat(
(
torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
grid_hw[:, 0] * grid_hw[:, 1],
)
)
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
for _, block in enumerate(self.blocks):
hidden_states = block(
hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
)
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def patch_merger(
x: torch.Tensor,
grid_hw: torch.Tensor,
merge_kernel_size: Tuple[int, int] = (2, 2),
) -> List[torch.Tensor]:
d_model = x.size(-1)
outputs = []
pre_sum = 0
for x_shape in grid_hw.tolist():
height, width = x_shape[0], x_shape[1]
# Get the current sequence
seq = x[pre_sum : pre_sum + height * width]
# Reshape along self.merge_kernel_size and concat to the last dimension
kernel_height, kernel_width = merge_kernel_size
new_height, new_width = height // kernel_height, width // kernel_width
reshaped_seq = seq.view(
new_height, kernel_height, new_width, kernel_width, d_model
)
reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
padded_seq = reshaped_seq.view(
new_height * new_width, kernel_height * kernel_width, -1
)
outputs.append(padded_seq)
pre_sum += height * width
return outputs
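For intuition, with the default (2, 2) kernel a 4x6 patch grid (24 tokens) is regrouped into 6 merged tokens, each carrying its four constituent patch embeddings; a quick sketch using the function above:

```python
# Quick sketch of patch_merger's regrouping; x encodes each patch's index so the
# grouping is visible in the output.
import torch

x = torch.arange(24).float().unsqueeze(-1).repeat(1, 8)   # 24 patches, d_model=8
grid_hw = torch.tensor([[4, 6]])                          # one image: 4x6 patch grid
out = patch_merger(x, grid_hw, merge_kernel_size=(2, 2))
print(len(out), out[0].shape)    # 1 torch.Size([6, 4, 8]): six 2x2 windows of patches
print(out[0][0, :, 0])           # tensor([0., 1., 6., 7.]): the top-left 2x2 window
```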
class MoonVitVLProjector(nn.Module):
def __init__(
self,
in_channels: int,
merge_kernel_size: Tuple[int, int],
hidden_act: str = "gelu",
ln_eps: float = 1e-5,
out_dim: int = 4096,
):
super().__init__()
self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]
self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
self.act = ACT2FN[hidden_act]
self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class MoonVitPretrainedModel(PreTrainedModel):
config_class = MoonViTConfig
model_type = "moonvit"
_no_split_modules = ["PackingTransformer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
config = deepcopy(config)
self.merge_kernel_size = config.merge_kernel_size
self.patch_size = config.patch_size
self.patch_embed = MoonVisionPatchEmbed(
out_dim=config.hidden_size,
patch_size=config.patch_size,
pos_emb_height=config.init_pos_emb_height,
pos_emb_width=config.init_pos_emb_width,
)
self.encoder = MoonVitEncoder(
hidden_dim=config.hidden_size,
num_layers=config.num_hidden_layers,
block_cfg={
"num_heads": config.num_attention_heads,
"hidden_dim": config.hidden_size,
"mlp_dim": config.intermediate_size,
"activation": PytorchGELUTanh(),
"attn_bias": True,
"attn_implementation": config._attn_implementation,
},
)
def forward(
self, pixel_values: torch.Tensor, grid_hw: torch.Tensor
) -> torch.Tensor:
"""
Args:
pixel_values (torch.Tensor): The input pixel values.
grid_hw (torch.Tensor): The grid height and width.
Returns:
torch.Tensor: The output tokens.
"""
hidden_states = self.patch_embed(pixel_values, grid_hw)
hidden_states = self.encoder(hidden_states, grid_hw)
hidden_states = patch_merger(
hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size
)
return hidden_states
......@@ -81,10 +81,20 @@ class TestOpenAIVisionServer(CustomTestCase):
text = response.choices[0].message.content
assert isinstance(text, str)
# `driver` is for gemma-3-it
assert "man" in text or "person" or "driver" in text, text
assert "cab" in text or "taxi" in text or "SUV" in text, text
assert (
"man" in text or "person" or "driver" in text
), f"text: {text}, should contain man, person or driver"
assert (
"cab" in text
or "taxi" in text
or "SUV" in text
or "vehicle" in text
or "car" in text
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
# MiniCPMO fails to recognize `iron`, but `hanging`
assert "iron" in text or "hang" in text, text
assert (
"iron" in text or "hang" in text or "cloth" in text or "holding" in text
), f"text: {text}, should contain iron, hang, cloth or holding"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
......@@ -132,7 +142,9 @@ class TestOpenAIVisionServer(CustomTestCase):
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
assert "man" in text or "cab" in text, text
assert (
"man" in text or "cab" in text
), f"text: {text}, should contain man or cab"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
......@@ -175,8 +187,12 @@ class TestOpenAIVisionServer(CustomTestCase):
print("-" * 30)
print(f"Multi images response:\n{text}")
print("-" * 30)
assert "man" in text or "cab" in text or "SUV" in text or "taxi" in text, text
assert "logo" in text or '"S"' in text or "SG" in text, text
assert (
"man" in text or "cab" in text or "SUV" in text or "taxi" in text
), f"text: {text}, should contain man, cab, SUV or taxi"
assert (
"logo" in text or '"S"' in text or "SG" in text
), f"text: {text}, should contain logo, S or SG"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
......@@ -305,9 +321,9 @@ class TestOpenAIVisionServer(CustomTestCase):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
regex = (
r"""\{\n"""
+ r""" "color": "[\w]+",\n"""
+ r""" "number_of_cars": [\d]+\n"""
r"""\{"""
+ r""""color":"[\w]+","""
+ r""""number_of_cars":[\d]+"""
+ r"""\}"""
)
......@@ -732,5 +748,33 @@ class TestGemma3itServer(TestOpenAIVisionServer):
pass
class TestKimiVLServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
cls.model = "moonshotai/Kimi-VL-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"kimi-vl",
"--context-length",
"4096",
"--tensor-parallel-size",
"2",
"--dtype",
"bfloat16",
],
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
if __name__ == "__main__":
unittest.main()