Unverified commit 0d503090, authored by Lifu Huang, committed by GitHub

Supported precomputed feature for Kimi VL (#6599)

parent 501efc3d
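In short: image inputs for Kimi VL can now arrive already preprocessed, carrying processor outputs or projected vision features directly instead of raw bytes or URLs. The sketch below only illustrates that idea, using the field names this diff reads (pixel_values, image_grid_thws, precomputed_features, modality); the exact request schema and the tensor shapes are assumptions, not part of this commit.

import torch

# Hypothetical image_data entries (shapes invented for illustration).
raw_entry = "https://example.com/cat.png"  # unchanged path: the server runs the vision tower itself

preprocessed_entry = {
    "modality": "IMAGE",
    "image_grid_thws": torch.tensor([[1, 28, 28]]),   # grid layout, assumed shape
    "precomputed_features": torch.randn(196, 2048),   # projected vision features, assumed shape
}

image_data = [preprocessed_entry]
# Mixing raw and preprocessed entries in one request is rejected upstream
# ("Unsupported: mixture of multimodal inputs where some but not all are preprocessed.").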
@@ -5,7 +5,7 @@ import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -382,3 +382,17 @@ class BaseMultimodalProcessor(ABC):
                 "Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
             )
         return ret
+
+    @staticmethod
+    def _extract_processor_features(
+        items: List[Any], attr_name: str
+    ) -> Optional[torch.Tensor]:
+        """
+        Helper function to concatenate attributes extracted from processor output.
+        """
+        values = [
+            getattr(item, attr_name)
+            for item in items
+            if getattr(item, attr_name) is not None
+        ]
+        return torch.concat(values) if values else None
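The new _extract_processor_features helper simply collects one attribute across the given items, drops None values, and concatenates the rest (or returns None if nothing is set). A self-contained sketch of that behavior; the Item dataclass is a stand-in for illustration, not an SGLang type:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Item:
    # Stand-in for a preprocessed multimodal item; only the attribute
    # being extracted matters here.
    image_grid_thws: Optional[torch.Tensor] = None


items = [
    Item(image_grid_thws=torch.tensor([[1, 28, 28]])),
    Item(),  # attribute is None, so it is skipped
    Item(image_grid_thws=torch.tensor([[1, 14, 14]])),
]

# Same logic as _extract_processor_features(items, "image_grid_thws"):
values = [
    getattr(item, "image_grid_thws")
    for item in items
    if getattr(item, "image_grid_thws") is not None
]
result = torch.concat(values) if values else None
print(result.shape)  # torch.Size([2, 3]); an all-None list would yield None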
...
-from typing import List, Union
+import re
+from typing import Any, Dict, List, Optional, Union
+
+import torch
 
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
@@ -17,20 +20,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<|media_pad|>"
+        self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
         self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
-        self.im_start = "<|media_start|>"
-        self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
-        self.im_end = "<|media_end|>"
-        self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
-        self.im_content = "<|media_content|>"
-        self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
 
     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes]],
+        image_data: List[Union[str, bytes, Dict]],
         input_text,
         request_obj,
         max_req_input_len,
@@ -45,30 +40,54 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
+            ),
             max_req_input_len=max_req_input_len,
         )
-        ret = self.process_mm_data(
-            input_text=base_output.input_text,
-            images=base_output.images,
-        )
-        input_ids = ret["input_ids"].flatten()
+
+        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
+        if not images_are_preprocessed:
+            ret = self.process_mm_data(
+                input_text=base_output.input_text,
+                images=base_output.images,
+            )
+            input_ids = ret["input_ids"].flatten()
+            image_grid_thws = ret["image_grid_hws"]
+            pixel_values = ret["pixel_values"]
+            precomputed_features = None
+        else:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+            image_grid_thws = self._extract_processor_features(
+                base_output.images, "image_grid_thws"
+            )
+            precomputed_features = self._extract_processor_features(
+                base_output.images, "precomputed_features"
+            )
+            pixel_values = self._extract_processor_features(
+                base_output.images, "pixel_values"
+            )
+
         image_offsets = self.get_mm_items_offset(
             input_ids=input_ids,
             mm_token_id=self.im_token_id,
         )
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": [
                 MultimodalDataItem(
-                    pixel_values=ret["pixel_values"],
-                    image_grid_thws=ret["image_grid_hws"],
+                    pixel_values=pixel_values,
+                    image_grid_thws=image_grid_thws,
+                    precomputed_features=precomputed_features,
                     modality=Modality.IMAGE,
                     image_offsets=image_offsets,
                 )
             ],
             "im_token_id": self.im_token_id,
-            "im_start_id": self.im_start_id,
-            "im_end_id": self.im_end_id,
-            "im_content_id": self.im_content_id,
         }
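The branch added above is the heart of the change: raw images still go through process_mm_data, while preprocessed items skip it, tokenize the text alone, and reuse the tensors they already carry. A stripped-down, framework-free sketch of that decision follows; every name in it is a stand-in for illustration, not the SGLang API.

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import torch


@dataclass
class ImageItem:
    # Stand-in for one multimodal item that may already be preprocessed.
    pixel_values: Optional[torch.Tensor] = None
    precomputed_features: Optional[torch.Tensor] = None


def build_image_inputs(
    items: List[ImageItem],
    run_processor: Callable[[List[ImageItem]], Dict[str, torch.Tensor]],
    tokenize_text: Callable[[], torch.Tensor],
):
    """Illustrative only: mirror the raw-vs-preprocessed branch for Kimi VL."""
    preprocessed = all(
        item.pixel_values is not None or item.precomputed_features is not None
        for item in items
    )
    if not preprocessed:
        # Raw path: the processor produces token ids and pixel values together.
        ret = run_processor(items)
        return ret["input_ids"], ret["pixel_values"], None

    # Preprocessed path: tokenize the prompt alone and concatenate whatever
    # tensors the caller shipped with the items.
    input_ids = tokenize_text()
    pixel_values = [i.pixel_values for i in items if i.pixel_values is not None]
    features = [i.precomputed_features for i in items if i.precomputed_features is not None]
    return (
        input_ids,
        torch.concat(pixel_values) if pixel_values else None,
        torch.concat(features) if features else None,
    )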
@@ -42,7 +42,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data=audio_data,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.image_token, audio_token=self.audio_token
+                image_token=self.image_token,
+                audio_token=self.audio_token,
             ),
         )
         if base_output is None:
...
...@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
if base_output.images: if base_output.images:
if images_are_preprocessed: if images_are_preprocessed:
all_image_grid_thws = [ image_grid_thw = self._extract_processor_features(
item.image_grid_thws base_output.images, "image_grid_thws"
for item in base_output.images
if item.image_grid_thws is not None
]
all_pixel_values = [
item.pixel_values
for item in base_output.images
if item.pixel_values is not None
]
all_precomputed_features = [
item.precomputed_features
for item in base_output.images
if item.precomputed_features is not None
]
image_grid_thw = (
torch.concat(all_image_grid_thws) if all_image_grid_thws else None
) )
pixel_values = ( precomputed_features = self._extract_processor_features(
torch.concat(all_pixel_values) if all_pixel_values else None base_output.images, "precomputed_features"
) )
precomputed_features = ( pixel_values = self._extract_processor_features(
torch.concat(all_precomputed_features) base_output.images, "pixel_values"
if all_precomputed_features
else None
) )
else: else:
image_grid_thw = ret["image_grid_thw"] image_grid_thw = ret["image_grid_thw"]
......
@@ -7,6 +7,7 @@ import requests
 import torch
 from PIL import Image
 from transformers import (
+    AutoModel,
     AutoProcessor,
     Gemma3ForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
@@ -51,6 +52,7 @@ class VLMInputTestBase:
             mem_fraction_static=0.8,
             enable_multimodal=True,
             disable_cuda_graph=True,
+            trust_remote_code=True,
         )
 
     def tearDown(self):
@@ -183,5 +185,32 @@ class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase):
         )
 
 
+class TestKimiVLImageUnderstandsImage(
+    VLMInputTestBase, unittest.IsolatedAsyncioTestCase
+):
+    model_path = "moonshotai/Kimi-VL-A3B-Instruct"
+    chat_template = "kimi-vl"
+
+    @classmethod
+    def _init_visual(cls):
+        model = AutoModel.from_pretrained(cls.model_path, trust_remote_code=True)
+        cls.vision_tower = model.vision_tower.eval().to(cls.device)
+        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
+        cls.visual = lambda tokenizer_output: cls.mm_projector(
+            cls.vision_tower(
+                pixel_values=tokenizer_output["pixel_values"],
+                grid_hws=tokenizer_output["image_grid_hws"],
+            )
+        )
+
+    def _pixel_values_image_data(self, processor_output):
+        return dict(
+            modality="IMAGE",
+            image_grid_thws=processor_output["image_grid_hws"],
+            pixel_values=processor_output["pixel_values"],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()