Unverified Commit 8fefdd32 authored by liwenju0, committed by GitHub

[Feature] add support kimi vl model (#5383)


Co-authored-by: wenju.li <wenju.li@deepctr.cn>
parent 403b855a
@@ -29,3 +29,4 @@ python3 -m sglang.launch_server \
| **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. |
| **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. |
| **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | `gemma-it` | Gemma 3’s larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. |
| **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | `kimi-vl` | Kimi-VL is Moonshot AI's multimodal model that understands images and generates text from them. |
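
As a usage sketch (not part of this diff): once a server is launched with this checkpoint, `--trust-remote-code`, and `--chat-template kimi-vl` (the flags used by the test added later in this commit), an image request can go through the OpenAI-compatible endpoint. The port and image URL below are illustrative assumptions.

# Hypothetical client call against a locally launched Kimi-VL server
# (port, API key, and image URL are assumptions for illustration).
import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
response = client.chat.completions.create(
    model="moonshotai/Kimi-VL-A3B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)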
@@ -42,6 +42,7 @@ runtime_common = [
    "uvicorn",
    "uvloop",
    "xgrammar==0.1.17",
    "blobfile==3.0.0"
]
srt = [
...
@@ -3,6 +3,8 @@ from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig

__all__ = [
    "ExaoneConfig",
@@ -10,4 +12,6 @@ __all__ = [
    "DbrxConfig",
    "DeepseekVL2Config",
    "MultiModalityConfig",
    "KimiVLConfig",
    "MoonViTConfig",
]
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from typing import Optional, Union
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
class KimiVLConfig(PretrainedConfig):
    model_type = "kimi_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, MoonViTConfig]] = None,
        text_config: Optional[Union[dict, DeepseekV2Config]] = None,
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        **kwargs
    ):
        if vision_config is None:
            vision_config = MoonViTConfig()
        elif isinstance(vision_config, dict):
            vision_config = MoonViTConfig(**vision_config)
        self.vision_config = vision_config

        if text_config is None:
            text_config = DeepseekV2Config()
        elif isinstance(text_config, dict):
            text_config = DeepseekV2Config(**text_config)
        self.text_config = text_config

        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id

        super().__init__(pad_token_id=pad_token_id, **kwargs)
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from transformers.configuration_utils import PretrainedConfig
class MoonViTConfig(PretrainedConfig):
    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        # Positional embedding config
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        # Transformer config
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Patch merger config
        self.merge_kernel_size = merge_kernel_size
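
As a quick sanity check (a sketch, not part of this commit), constructing the new config with no arguments should fall back to the MoonViT and DeepSeek-V2 defaults defined above; it assumes DeepseekV2Config() also accepts default construction.

# Sketch: default construction of the new configs; values mirror the defaults above.
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig

cfg = KimiVLConfig()
assert isinstance(cfg.vision_config, MoonViTConfig)
assert cfg.vision_config.patch_size == 14
assert cfg.vision_config.merge_kernel_size == (2, 2)
assert cfg.media_placeholder_token_id == 163605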
@@ -176,6 +176,13 @@ class ModelConfig:
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
        else:
            self.attention_arch = AttentionArch.MHA
@@ -530,6 +537,7 @@ multimodal_model_archs = [
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "CLIPModel",
    "KimiVLForConditionalGeneration",
]
...
@@ -806,6 +806,24 @@ register_conv_template(
    )
)
# Reference: https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/chat_template.jinja
register_conv_template(
    Conversation(
        name="kimi-vl",
        system_message="You are a helpful assistant",
        system_template="<|im_system|>system<|im_middle|>{system_message}",
        roles=(
            "<|im_user|>user<|im_middle|>",
            "<|im_assistant|>assistant<|im_middle|>",
        ),
        messages=[],
        sep="<|im_end|>",
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        stop_str="<|im_end|>",
        image_token="<|media_start|>image<|media_content|><|media_pad|><|media_end|>",
    )
)
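
To make the token layout concrete, here is a rough, hand-assembled sketch (not produced by the Conversation machinery; the user text is made up) of how a single-turn prompt with one image is laid out under this template, following the referenced chat_template.jinja:

# Illustrative sketch only: manual assembly of a kimi-vl style prompt.
system = "<|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|>"
user = (
    "<|im_user|>user<|im_middle|>"
    "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"
    "What is in this image?<|im_end|>"
)
assistant_prefix = "<|im_assistant|>assistant<|im_middle|>"
prompt = system + user + assistant_prefix
print(prompt)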
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
@@ -888,3 +906,10 @@ def match_openbmb_minicpm(model_path: str):
        return "minicpmv"
    elif "minicpm-o" in model_path:
        return "minicpmo"
@register_conv_template_matching_function
def match_moonshot_kimivl(model_path: str):
    model_path = model_path.lower()
    if "kimi" in model_path and "vl" in model_path:
        return "kimi-vl"
@@ -35,6 +35,7 @@ from sglang.srt.configs import (
    DbrxConfig,
    DeepseekVL2Config,
    ExaoneConfig,
    KimiVLConfig,
    MultiModalityConfig,
)
from sglang.srt.connector import create_remote_connector
@@ -46,6 +47,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ExaoneConfig.model_type: ExaoneConfig,
    DeepseekVL2Config.model_type: DeepseekVL2Config,
    MultiModalityConfig.model_type: MultiModalityConfig,
    KimiVLConfig.model_type: KimiVLConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
...
import asyncio
import math
from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
# Compatible with KimiVLForConditionalGeneration
class KimiVLImageProcessor(SGLangBaseProcessor):
    models = [KimiVLForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<|media_pad|>"
        self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)

        self.im_start = "<|media_start|>"
        self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)

        self.im_end = "<|media_end|>"
        self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)

        self.im_content = "<|media_content|>"
        self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None
        if isinstance(image_data, str):
            image_data = [image_data]

        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
            max_req_input_len=max_req_input_len,
        )
        ret = self.process_mm_data(
            input_text=base_output.input_text,
            images=base_output.images,
        )
        return {
            "input_ids": ret["input_ids"].flatten().tolist(),
            "mm_items": [
                MultimodalDataItem(
                    pixel_values=ret["pixel_values"],
                    image_grid_thws=ret["image_grid_hws"],
                    modality=Modality.IMAGE,
                )
            ],
            "im_token_id": self.im_token_id,
            "im_start_id": self.im_start_id,
            "im_end_id": self.im_end_id,
            "im_content_id": self.im_content_id,
        }
@@ -752,7 +752,7 @@ class DeepseekV2AttentionMLA(nn.Module):
        q_nope_out = q_nope_out.transpose(0, 1)

        k_nope = latent_cache[..., : self.kv_lora_rank]
        k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
@@ -1391,6 +1391,9 @@ class DeepseekV2Model(nn.Module):
        self.dp_size = get_attention_dp_size()

    def get_input_embeddings(self) -> torch.Tensor:
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
...
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import logging
import math
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers.activations import GELUActivation
from sglang.srt.configs import KimiVLConfig
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
from sglang.srt.models.kimi_vl_moonvit import MoonVitPretrainedModel
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
# For dummy input only
@dataclass
class MaxImageTokenMeta:
    width: int = 1024
    height: int = 1024


class KimiVLMultiModalProjector(nn.Module):
    def __init__(self, config: KimiVLConfig):
        super().__init__()

        self.hidden_size = (
            config.vision_config.hidden_size
            * config.vision_config.merge_kernel_size[0]
            * config.vision_config.merge_kernel_size[1]
        )

        self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-5)
        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.act = GELUActivation()
        self.act = QuickGELU()
        self.linear_2 = nn.Linear(
            self.hidden_size, config.text_config.hidden_size, bias=True
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
class KimiVLForConditionalGeneration(nn.Module):
    def __init__(
        self,
        config: KimiVLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        **kwargs,  # fix init_tts argument error
    ) -> None:
        super().__init__()
        self.config = config

        assert isinstance(config.vision_config, MoonViTConfig)
        self.vision_tower = MoonVitPretrainedModel(config.vision_config)

        self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
        self.quant_config = quant_config
        text_config = copy.deepcopy(config.text_config)
        text_config.architectures = ["DeepseekV2ForCausalLM"]
        self.language_model = DeepseekV2ForCausalLM(
            config=text_config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        pixel_values = (
            torch.cat([item.pixel_values for item in items], dim=0)
            .type(self.vision_tower.dtype)
            .to(self.vision_tower.device)
        )
        image_grid_thws = torch.concat(
            [item.image_grid_thws for item in items], dim=0
        ).to(self.vision_tower.device)
        image_features = self.vision_tower(pixel_values, image_grid_thws)
        assert isinstance(image_features, list)
        # lengths = [x.shape[0] for x in image_features]
        res = self.multi_modal_projector(torch.cat(image_features))  # .split(lengths)
        return res

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        # Get all special token IDs
        pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            image_data_embedding_func=self.get_image_feature,
            positions=positions,
        )

        return hidden_states
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        config = self.config.text_config
        _KEYS_TO_MODIFY_MAPPING = {
            # "language_model.lm_head": "lm_head",
            # "language_model.model": "language_model",
        }
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        if not config.use_mla:
            stacked_params_mapping += [
                (".qkv_proj", ".q_proj", "q"),
                (".qkv_proj", ".k_proj", "k"),
                (".qkv_proj", ".v_proj", "v"),
            ]
        if getattr(config, "n_routed_experts", None):
            # Params for weights, fp8 weight scales, fp8 activation scales
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=config.n_routed_experts,
            )
        else:
            expert_params_mapping = []

        params_dict = dict(self.named_parameters())
        for args in weights:
            name, loaded_weight = args[:2]
            kwargs = args[2] if len(args) > 2 else {}
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    # We have mlp.experts[0].gate_proj in the checkpoint.
                    # Since we handle the experts below in expert_params_mapping,
                    # we need to skip here BEFORE we update the name, otherwise
                    # name will be updated to mlp.experts[0].gate_up_proj, which
                    # will then be updated below in expert_params_mapping
                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                    if ("mlp.experts." in name) and name not in params_dict:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id, **kwargs)
                    break
                else:
                    for idx, (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                    ) in enumerate(expert_params_mapping):
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight,
                            name,
                            expert_id=expert_id,
                            shard_id=shard_id,
                            **kwargs,
                        )
                        break
                    else:
                        use_default_weight_loading = True
            if use_default_weight_loading:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                # if is_pp_missing_parameter(name, self):
                #     continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, **kwargs)
        self.language_model.post_load_weights()
def get_spec_layer_idx_from_weight_name(
    config: DeepseekV2Config, weight_name: str
) -> Optional[int]:
    if hasattr(config, "num_nextn_predict_layers") and (
        config.num_nextn_predict_layers > 0
    ):
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
            if weight_name.startswith(f"model.layers.{layer_idx+i}."):
                return layer_idx + i
    return None


EntryClass = [KimiVLForConditionalGeneration]
@@ -81,10 +81,20 @@ class TestOpenAIVisionServer(CustomTestCase):
        text = response.choices[0].message.content
        assert isinstance(text, str)
        # `driver` is for gemma-3-it
        assert (
            "man" in text or "person" in text or "driver" in text
        ), f"text: {text}, should contain man, person or driver"
        assert (
            "cab" in text
            or "taxi" in text
            or "SUV" in text
            or "vehicle" in text
            or "car" in text
        ), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
        # MiniCPMO fails to recognize `iron`, but `hanging`
        assert (
            "iron" in text or "hang" in text or "cloth" in text or "holding" in text
        ), f"text: {text}, should contain iron, hang, cloth or holding"
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
@@ -132,7 +142,9 @@ class TestOpenAIVisionServer(CustomTestCase):
        assert response.choices[0].message.role == "assistant"
        text = response.choices[0].message.content
        assert isinstance(text, str)
        assert (
            "man" in text or "cab" in text
        ), f"text: {text}, should contain man or cab"
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
@@ -175,8 +187,12 @@ class TestOpenAIVisionServer(CustomTestCase):
        print("-" * 30)
        print(f"Multi images response:\n{text}")
        print("-" * 30)
        assert (
            "man" in text or "cab" in text or "SUV" in text or "taxi" in text
        ), f"text: {text}, should contain man, cab, SUV or taxi"
        assert (
            "logo" in text or '"S"' in text or "SG" in text
        ), f"text: {text}, should contain logo, S or SG"
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
@@ -305,9 +321,9 @@ class TestOpenAIVisionServer(CustomTestCase):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        regex = (
            r"""\{"""
            + r""""color":"[\w]+","""
            + r""""number_of_cars":[\d]+"""
            + r"""\}"""
        )
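
The regex tightening above drops the newlines and indentation from the expected JSON; a quick standalone check of the new pattern (a sketch, not part of the test suite):

# Sketch: the compact pattern accepts minified JSON and rejects the old pretty-printed form.
import re

regex = (
    r"""\{"""
    + r""""color":"[\w]+","""
    + r""""number_of_cars":[\d]+"""
    + r"""\}"""
)
assert re.fullmatch(regex, '{"color":"red","number_of_cars":2}')
assert not re.fullmatch(regex, '{\n  "color": "red",\n  "number_of_cars": 2\n}')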
@@ -732,5 +748,33 @@ class TestGemma3itServer(TestOpenAIVisionServer):
        pass
class TestKimiVLServer(TestOpenAIVisionServer):
    @classmethod
    def setUpClass(cls):
        cls.model = "moonshotai/Kimi-VL-A3B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--chat-template",
                "kimi-vl",
                "--context-length",
                "4096",
                "--tensor-parallel-size",
                "2",
                "--dtype",
                "bfloat16",
            ],
        )
        cls.base_url += "/v1"

    def test_video_chat_completion(self):
        pass
if __name__ == "__main__":
    unittest.main()