Unverified commit 0d503090 authored by Lifu Huang, committed by GitHub

Support precomputed features for Kimi VL (#6599)

parent 501efc3d
@@ -5,7 +5,7 @@ import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -382,3 +382,17 @@ class BaseMultimodalProcessor(ABC):
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
)
return ret
+
+    @staticmethod
+    def _extract_processor_features(
+        items: List[Any], attr_name: str
+    ) -> Optional[torch.Tensor]:
+        """
+        Extract `attr_name` from each item and concatenate the non-None
+        values along dim 0; return None if no item provides the attribute.
+        """
+        values = [
+            getattr(item, attr_name)
+            for item in items
+            if getattr(item, attr_name) is not None
+        ]
+        return torch.concat(values) if values else None
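The helper above is shared by the Kimi-VL and Qwen2.5-VL processors further down. A minimal, self-contained usage sketch (the SimpleNamespace stand-ins are illustrative only, not part of this commit):

import torch
from types import SimpleNamespace

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
)

# Stand-in items mimicking already-processed multimodal inputs.
items = [
    SimpleNamespace(pixel_values=torch.zeros(2, 4)),
    SimpleNamespace(pixel_values=None),
    SimpleNamespace(pixel_values=torch.ones(1, 4)),
]

# Non-None values are concatenated along dim 0 -> shape (3, 4).
feats = BaseMultimodalProcessor._extract_processor_features(items, "pixel_values")
assert feats.shape == (3, 4)

# When no item carries a value for the attribute, the helper returns None.
assert (
    BaseMultimodalProcessor._extract_processor_features(
        [SimpleNamespace(pixel_values=None)], "pixel_values"
    )
    is None
)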
-from typing import List, Union
+import re
+from typing import Any, Dict, List, Optional, Union
import torch
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
@@ -17,20 +20,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|media_pad|>"
+        self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
self.im_start = "<|media_start|>"
self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
self.im_end = "<|media_end|>"
self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
self.im_content = "<|media_content|>"
self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
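    # Editorial note (not in the source): IMAGE_TOKEN_REGEX matches a whole
    # run of consecutive <|media_pad|> tokens, e.g. a single match for
    # "<|media_pad|><|media_pad|><|media_pad|>", so each image's entire
    # padding span is treated as one placeholder when load_mm_data splits
    # the prompt below.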
async def process_mm_data_async(
self,
-        image_data: List[Union[str, bytes]],
+        image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
@@ -45,30 +40,54 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
+            ),
max_req_input_len=max_req_input_len,
)
-        ret = self.process_mm_data(
-            input_text=base_output.input_text,
-            images=base_output.images,
-        )
-        input_ids = ret["input_ids"].flatten()
+        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
+        if not images_are_preprocessed:
+            ret = self.process_mm_data(
+                input_text=base_output.input_text,
+                images=base_output.images,
+            )
+            input_ids = ret["input_ids"].flatten()
+            image_grid_thws = ret["image_grid_hws"]
+            pixel_values = ret["pixel_values"]
+            precomputed_features = None
+        else:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+            image_grid_thws = self._extract_processor_features(
+                base_output.images, "image_grid_thws"
+            )
+            precomputed_features = self._extract_processor_features(
+                base_output.images, "precomputed_features"
+            )
+            pixel_values = self._extract_processor_features(
+                base_output.images, "pixel_values"
+            )
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.im_token_id,
)
return {
"input_ids": input_ids.tolist(),
"mm_items": [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=ret["image_grid_hws"],
pixel_values=pixel_values,
image_grid_thws=image_grid_thws,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
],
"im_token_id": self.im_token_id,
"im_start_id": self.im_start_id,
"im_end_id": self.im_end_id,
"im_content_id": self.im_content_id,
}
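With this branch in place, a client can hand the server already-encoded image features instead of pixels. A hedged sketch of what a precomputed entry might look like (the attribute names mirror the `_extract_processor_features` calls above; the tensor shapes and the exact accepted schema are assumptions, not a documented API):

import torch

# Hypothetical precomputed image entry; shapes are placeholders and depend
# on the Kimi-VL vision tower and projector.
entry = dict(
    modality="IMAGE",
    image_grid_thws=torch.tensor([[1, 32, 32]]),  # assumed (t, h, w) per image
    precomputed_features=torch.randn(256, 2048),  # assumed projector output
)

# The processor would then take the tokenizer-only path above, attaching the
# entry's features to the MultimodalDataItem instead of running the HF
# image processor.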
@@ -42,7 +42,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.image_token, audio_token=self.audio_token
+                image_token=self.image_token,
+                audio_token=self.audio_token,
),
)
if base_output is None:
@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
if base_output.images:
if images_are_preprocessed:
-                all_image_grid_thws = [
-                    item.image_grid_thws
-                    for item in base_output.images
-                    if item.image_grid_thws is not None
-                ]
-                all_pixel_values = [
-                    item.pixel_values
-                    for item in base_output.images
-                    if item.pixel_values is not None
-                ]
-                all_precomputed_features = [
-                    item.precomputed_features
-                    for item in base_output.images
-                    if item.precomputed_features is not None
-                ]
-                image_grid_thw = (
-                    torch.concat(all_image_grid_thws) if all_image_grid_thws else None
-                )
-                pixel_values = (
-                    torch.concat(all_pixel_values) if all_pixel_values else None
-                )
-                precomputed_features = (
-                    torch.concat(all_precomputed_features)
-                    if all_precomputed_features
-                    else None
-                )
+                image_grid_thw = self._extract_processor_features(
+                    base_output.images, "image_grid_thws"
+                )
+                precomputed_features = self._extract_processor_features(
+                    base_output.images, "precomputed_features"
+                )
+                pixel_values = self._extract_processor_features(
+                    base_output.images, "pixel_values"
+                )
else:
image_grid_thw = ret["image_grid_thw"]
@@ -7,6 +7,7 @@ import requests
import torch
from PIL import Image
from transformers import (
+    AutoModel,
AutoProcessor,
Gemma3ForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
@@ -51,6 +52,7 @@ class VLMInputTestBase:
mem_fraction_static=0.8,
enable_multimodal=True,
disable_cuda_graph=True,
+            trust_remote_code=True,
)
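        # Editorial note (not in the source): trust_remote_code=True is needed
        # because Kimi-VL's modeling code ships with the checkpoint rather than
        # with transformers, matching the AutoModel.from_pretrained(...,
        # trust_remote_code=True) call in the new test class below.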
def tearDown(self):
@@ -183,5 +185,32 @@ class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCa
)
+class TestKimiVLImageUnderstandsImage(
+    VLMInputTestBase, unittest.IsolatedAsyncioTestCase
+):
+    model_path = "moonshotai/Kimi-VL-A3B-Instruct"
+    chat_template = "kimi-vl"
+
+    @classmethod
+    def _init_visual(cls):
+        model = AutoModel.from_pretrained(cls.model_path, trust_remote_code=True)
+        cls.vision_tower = model.vision_tower.eval().to(cls.device)
+        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
+        cls.visual = lambda tokenizer_output: cls.mm_projector(
+            cls.vision_tower(
+                pixel_values=tokenizer_output["pixel_values"],
+                grid_hws=tokenizer_output["image_grid_hws"],
+            )
+        )
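    # Editorial note (not in the source): cls.visual maps raw processor output
    # (pixel_values + image_grid_hws) through the vision tower and multimodal
    # projector; its output plays the role of the "precomputed_features"
    # tensor that the server-side Kimi-VL processor now accepts.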
+
+    def _pixel_values_image_data(self, processor_output):
+        return dict(
+            modality="IMAGE",
+            image_grid_thws=processor_output["image_grid_hws"],
+            pixel_values=processor_output["pixel_values"],
+        )
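    # A hypothetical counterpart for exercising the precomputed path (sketch
    # only; the real test base may name and wire this differently):
    #
    #     def _precomputed_features_image_data(self, processor_output):
    #         return dict(
    #             modality="IMAGE",
    #             image_grid_thws=processor_output["image_grid_hws"],
    #             precomputed_features=type(self).visual(processor_output),
    #         )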
if __name__ == "__main__":
unittest.main()