Unverified Commit 4061dcf4 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Enable Kimi k25 processor test (#33562)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 0aca8b8c
...@@ -995,6 +995,31 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -995,6 +995,31 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
) )
# Kimi-K2.5
def run_kimi_k25(questions: list[str], modality: str) -> ModelRequestData:
    """Build the engine arguments and prompts for ``moonshotai/Kimi-K2.5``.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: Must be ``"vision_chunk"``; anything else trips the assert.

    Returns:
        A ``ModelRequestData`` holding the engine arguments and the prompts.
    """
    assert modality == "vision_chunk"

    # Chat-template text placed before and after each question. The head
    # carries the single media placeholder the prompt refers to.
    head = (
        "<|im_user|>user<|media_begin|>image<|media_content|>"
        "<|media_pad|><|media_end|>"
    )
    tail = "<|im_end|><|im_assistant|>assistant<|im_middle|>"

    prompts = []
    for question in questions:
        prompts.append(f"{head}{question}{tail}")

    return ModelRequestData(
        engine_args=EngineArgs(
            model="moonshotai/Kimi-K2.5",
            trust_remote_code=True,
            max_model_len=4096,
            # At most one item of the requested modality per prompt.
            limit_mm_per_prompt={modality: 1},
            tensor_parallel_size=4,
        ),
        prompts=prompts,
    )
# LightOnOCR # LightOnOCR
def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData: def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
...@@ -2110,6 +2135,7 @@ model_example_map = { ...@@ -2110,6 +2135,7 @@ model_example_map = {
"keye_vl": run_keye_vl, "keye_vl": run_keye_vl,
"keye_vl1_5": run_keye_vl1_5, "keye_vl1_5": run_keye_vl1_5,
"kimi_vl": run_kimi_vl, "kimi_vl": run_kimi_vl,
"kimi_k25": run_kimi_k25,
"lightonocr": run_lightonocr, "lightonocr": run_lightonocr,
"lfm2_vl": run_lfm2_vl, "lfm2_vl": run_lfm2_vl,
"llama4": run_llama4, "llama4": run_llama4,
...@@ -2196,6 +2222,19 @@ def get_multi_modal_input(args): ...@@ -2196,6 +2222,19 @@ def get_multi_modal_input(args):
"questions": vid_questions, "questions": vid_questions,
} }
if args.modality == "vision_chunk":
# Input vision chunks and question
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
vision_chunk_questions = [
"What is the content of this image chunk?",
"Describe the content of this image chunk in detail.",
]
return {
"data": {"type": "image", "image": image},
"questions": vision_chunk_questions,
}
msg = f"Modality {args.modality} is not supported." msg = f"Modality {args.modality} is not supported."
raise ValueError(msg) raise ValueError(msg)
...@@ -2278,7 +2317,7 @@ def parse_args(): ...@@ -2278,7 +2317,7 @@ def parse_args():
"--modality", "--modality",
type=str, type=str,
default="image", default="image",
choices=["image", "video"], choices=["image", "video", "vision_chunk"],
help="Modality of the input.", help="Modality of the input.",
) )
parser.add_argument( parser.add_argument(
...@@ -2355,7 +2394,7 @@ def main(args): ...@@ -2355,7 +2394,7 @@ def main(args):
req_data = model_example_map[model](questions, modality) req_data = model_example_map[model](questions, modality)
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0, "vision_chunk": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict( req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {} req_data.engine_args.limit_mm_per_prompt or {}
) )
......
...@@ -214,6 +214,28 @@ def get_text_token_prompts( ...@@ -214,6 +214,28 @@ def get_text_token_prompts(
return text_prompt, token_prompt return text_prompt, token_prompt
def random_vision_chunk(
    rng: np.random.RandomState,
    min_wh: int,
    max_wh: int,
    min_frames: int,
    max_frames: int,
) -> dict:
    """Generate a random vision chunk: a single image or a stack of frames.

    The frame count is drawn uniformly from ``[min_frames, max_frames]`` and
    one square side length is drawn uniformly from ``[min_wh, max_wh]``.

    Args:
        rng: Random state that supplies every random draw.
        min_wh: Smallest side length (inclusive).
        max_wh: Largest side length (inclusive).
        min_frames: Smallest number of frames (inclusive).
        max_frames: Largest number of frames (inclusive).

    Returns:
        ``{"type": "image", "image": ...}`` when exactly one frame is drawn,
        otherwise ``{"type": "video_chunk", "video_chunk": array}`` where
        ``array`` is a ``uint8`` array of shape ``(num_frames, wh, wh, 3)``.

    Raises:
        ValueError: If the drawn frame count is less than one.
    """
    num_frames = rng.randint(min_frames, max_frames + 1)
    if num_frames < 1:
        raise ValueError("need at least one frame to build a vision chunk")

    # Draw the side length once per chunk. Drawing it per frame produces
    # frames of different shapes, which cannot be stacked into one array.
    wh = rng.randint(min_wh, max_wh + 1)

    if num_frames == 1:
        # Single image chunk.
        # NOTE(review): presumably random_image treats its bounds as
        # half-open, so (wh, wh + 1) yields a wh x wh image -- confirm.
        image = random_image(rng, wh, wh + 1)
        return {"type": "image", "image": image}

    video_array = rng.randint(
        0, 256, size=(num_frames, wh, wh, 3), dtype=np.uint8
    )
    return {"type": "video_chunk", "video_chunk": video_array}
def _test_processing_correctness( def _test_processing_correctness(
model_id_or_arch: str, model_id_or_arch: str,
hit_rate: float, hit_rate: float,
...@@ -291,6 +313,7 @@ def _test_processing_correctness( ...@@ -291,6 +313,7 @@ def _test_processing_correctness(
"image": Image.new("RGB", size=(128, 128)), "image": Image.new("RGB", size=(128, 128)),
"video": np.zeros((4, 128, 128, 3), dtype=np.uint8), "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
"audio": (np.zeros((512,)), 16000), "audio": (np.zeros((512,)), 16000),
"vision_chunk": {"type": "image", "image": Image.new("RGB", size=(128, 128))},
} }
input_factory = { input_factory = {
"image": partial(random_image, rng, min_wh=128, max_wh=256), "image": partial(random_image, rng, min_wh=128, max_wh=256),
...@@ -298,6 +321,9 @@ def _test_processing_correctness( ...@@ -298,6 +321,9 @@ def _test_processing_correctness(
random_video, rng, min_frames=2, max_frames=16, min_wh=128, max_wh=256 random_video, rng, min_frames=2, max_frames=16, min_wh=128, max_wh=256
), ),
"audio": partial(random_audio, rng, min_len=512, max_len=1024, sr=16000), "audio": partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
"vision_chunk": partial(
random_vision_chunk, rng, min_wh=128, max_wh=256, min_frames=1, max_frames=1
),
} }
for batch_idx in range(num_batches): for batch_idx in range(num_batches):
...@@ -413,11 +439,6 @@ def test_processing_correctness( ...@@ -413,11 +439,6 @@ def test_processing_correctness(
"Qwen-VL tokenizer requires downloading a font file from " "Qwen-VL tokenizer requires downloading a font file from "
"servers that often refuse connections in CI" "servers that often refuse connections in CI"
) )
if model_id == "moonshotai/Kimi-K2.5":
# FIXME(Isaac): Fix Kimi-K2.5's offline inference about vision chunks.
pytest.skip(
"Kimi-K2.5's offline inference has issues about vision chunks. Fix later."
)
_test_processing_correctness( _test_processing_correctness(
model_id, model_id,
......
...@@ -96,16 +96,20 @@ class MoonshotKimiVAutoProcessor(ProcessorMixin): ...@@ -96,16 +96,20 @@ class MoonshotKimiVAutoProcessor(ProcessorMixin):
attributes = ["tokenizer"] attributes = ["tokenizer"]
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
def __init__(self, media_processor=None, tokenizer=None): def __init__(
self, media_processor=None, tokenizer=None, media_token_id: int | None = None
):
super().__init__(tokenizer) super().__init__(tokenizer)
self.media_processor = media_processor self.media_processor = media_processor
self.media_token_id = media_token_id
assert self.media_token_id is not None
# We do not support str input for text here # We do not support str input for text here
def __call__( def __call__(
self, self,
vision_chunks: list[VisionChunk] | None = None, vision_chunks: list[VisionChunk] | None = None,
*, *,
text: list[int], text: list[int] | str,
**kwargs, **kwargs,
) -> BatchFeature: ) -> BatchFeature:
""" """
...@@ -122,13 +126,30 @@ class MoonshotKimiVAutoProcessor(ProcessorMixin): ...@@ -122,13 +126,30 @@ class MoonshotKimiVAutoProcessor(ProcessorMixin):
- **grid_thws** -- list of image 3D grid in LLM. Returned when `vision_chunks` is not `None`. - **grid_thws** -- list of image 3D grid in LLM. Returned when `vision_chunks` is not `None`.
""" """
mm_inputs = {} mm_inputs = {}
input_ids = self.tokenizer.encode(text) if isinstance(text, str) else text
if vision_chunks is not None: if vision_chunks is not None:
assert isinstance(vision_chunks, list) assert isinstance(vision_chunks, list)
mm_inputs = self.media_processor.preprocess(vision_chunks) mm_inputs = self.media_processor.preprocess(vision_chunks)
num_tokens_per_chunk = [
self.media_processor.media_tokens_calculator(chunk)
for chunk in vision_chunks
]
new_input_ids = []
for token in input_ids:
if token == self.media_token_id:
new_input_ids.extend(
[self.media_token_id] * num_tokens_per_chunk.pop(0)
)
else:
new_input_ids.append(token)
input_ids = new_input_ids
# XXX: _apply_hf_processor_text_mm will call tolist() on input_ids # XXX: _apply_hf_processor_text_mm will call tolist() on input_ids
return BatchFeature( return BatchFeature(
data={ data={
"input_ids": torch.tensor([text]), "input_ids": torch.tensor([input_ids]),
**mm_inputs, **mm_inputs,
} }
) )
...@@ -152,6 +173,7 @@ class KimiK25ProcessingInfo(BaseProcessingInfo): ...@@ -152,6 +173,7 @@ class KimiK25ProcessingInfo(BaseProcessingInfo):
self.hf_processor = MoonshotKimiVAutoProcessor( self.hf_processor = MoonshotKimiVAutoProcessor(
media_processor=self.media_processor, media_processor=self.media_processor,
tokenizer=self.get_tokenizer(), tokenizer=self.get_tokenizer(),
media_token_id=self.media_token_id,
) )
self.media_tokens_calculator = self.media_processor.media_tokens_calculator self.media_tokens_calculator = self.media_processor.media_tokens_calculator
...@@ -174,9 +196,9 @@ class KimiK25DummyInputsBuilder(BaseDummyInputsBuilder[KimiK25ProcessingInfo]): ...@@ -174,9 +196,9 @@ class KimiK25DummyInputsBuilder(BaseDummyInputsBuilder[KimiK25ProcessingInfo]):
self.media_token_id = self.info.media_token_id self.media_token_id = self.info.media_token_id
self.frame_per_chunk = self.info.media_processor.num_frames_per_chunk self.frame_per_chunk = self.info.media_processor.num_frames_per_chunk
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]: def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_media = mm_counts.get("vision_chunk", 0) num_media = mm_counts.get("vision_chunk", 0)
return [self.media_token_id] * num_media return "<|media_pad|>" * num_media
def get_dummy_mm_items(self): def get_dummy_mm_items(self):
dummy_videos = self._get_dummy_images( dummy_videos = self._get_dummy_images(
......
...@@ -668,6 +668,8 @@ class MultiModalDataParser: ...@@ -668,6 +668,8 @@ class MultiModalDataParser:
return None return None
if self.is_embeddings(data): if self.is_embeddings(data):
raise ValueError("Do not support embedding data for vision_chunk right now") raise ValueError("Do not support embedding data for vision_chunk right now")
if isinstance(data, dict):
data = [data]
return VisionChunkProcessorItems(data) return VisionChunkProcessorItems(data)
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]: def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment