Unverified Commit 2566dca2 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix deepseek-ocr multi-image inference and add...


[Bugfix] Fix deepseek-ocr multi-image inference and add `merge_by_field_config=True` with tensor schema support (#27361)
Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent b4fda58a
...@@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple): ...@@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
stop_token_ids: list[int] | None = None stop_token_ids: list[int] | None = None
chat_template: str | None = None chat_template: str | None = None
lora_requests: list[LoRARequest] | None = None lora_requests: list[LoRARequest] | None = None
sampling_params: SamplingParams | None = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
...@@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
model_name = "deepseek-ai/DeepSeek-OCR"
engine_args = EngineArgs(
model=model_name,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
logits_processors=[NGramPerReqLogitsProcessor],
)
placeholder = "<image>\n" * len(image_urls)
prompt = placeholder + question
# The following sampling params config is taken from
# the official Deepseek-OCR inference example.
# (IMPORTANT) Use the custom logits processor and avoid skipping
# special tokens for this model for the optimal OCR performance.
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
# ngram logit processor args
extra_args=dict(
ngram_size=30,
window_size=90,
# whitelist: <td>, </td>
whitelist_token_ids={128821, 128822},
),
skip_special_tokens=False,
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
sampling_params=sampling_params,
)
def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData: def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it" model_name = "google/gemma-3-4b-it"
...@@ -1253,6 +1294,7 @@ model_example_map = { ...@@ -1253,6 +1294,7 @@ model_example_map = {
"bee": load_bee, "bee": load_bee,
"command_a_vision": load_command_a_vision, "command_a_vision": load_command_a_vision,
"deepseek_vl_v2": load_deepseek_vl2, "deepseek_vl_v2": load_deepseek_vl2,
"deepseek_ocr": load_deepseek_ocr,
"gemma3": load_gemma3, "gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl, "h2ovl_chat": load_h2ovl,
"hyperclovax_seed_vision": load_hyperclovax_seed_vision, "hyperclovax_seed_vision": load_hyperclovax_seed_vision,
...@@ -1325,9 +1367,13 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None) ...@@ -1325,9 +1367,13 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
engine_args = asdict(req_data.engine_args) | {"seed": seed} engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
sampling_params = SamplingParams( sampling_params = (
SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
) )
if req_data.sampling_params is None
else req_data.sampling_params
)
outputs = llm.chat( outputs = llm.chat(
[ [
{ {
......
...@@ -332,6 +332,7 @@ def _test_processing_correctness_one( ...@@ -332,6 +332,7 @@ def _test_processing_correctness_one(
"facebook/chameleon-7b", "facebook/chameleon-7b",
"CohereLabs/command-a-vision-07-2025", "CohereLabs/command-a-vision-07-2025",
"deepseek-ai/deepseek-vl2-tiny", "deepseek-ai/deepseek-vl2-tiny",
"deepseek-ai/DeepSeek-OCR",
"baidu/ERNIE-4.5-VL-28B-A3B-PT", "baidu/ERNIE-4.5-VL-28B-A3B-PT",
"adept/fuyu-8b", "adept/fuyu-8b",
"google/gemma-3-4b-it", "google/gemma-3-4b-it",
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -53,6 +54,7 @@ from vllm.transformers_utils.processors.deepseek_ocr import ( ...@@ -53,6 +54,7 @@ from vllm.transformers_utils.processors.deepseek_ocr import (
count_tiles, count_tiles,
) )
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.sample.logits_processor import ( from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor, AdapterLogitsProcessor,
RequestLogitsProcessor, RequestLogitsProcessor,
...@@ -65,6 +67,28 @@ from .deepseek_vl2 import MlpProjector ...@@ -65,6 +67,28 @@ from .deepseek_vl2 import MlpProjector
_IMAGE_TOKEN = "<image>" _IMAGE_TOKEN = "<image>"
class DeepseekOCRImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- n: Number of images
- p: Number of patches
- base_size: Base size of the processor
- image_size: Image size of the processor
"""
type: Literal["pixel_values"]
data: Annotated[
torch.Tensor,
TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}),
]
images_crop: Annotated[
torch.Tensor,
TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
]
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
class NoRepeatNGramLogitsProcessor: class NoRepeatNGramLogitsProcessor:
def __init__( def __init__(
self, self,
...@@ -260,10 +284,15 @@ class DeepseekOCRMultiModalProcessor( ...@@ -260,10 +284,15 @@ class DeepseekOCRMultiModalProcessor(
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
images_spatial_crop=MultiModalFieldConfig.batched("image"), images_spatial_crop=MultiModalFieldConfig.batched("image"),
images_crop=MultiModalFieldConfig.batched("image"), images_crop=MultiModalFieldConfig.flat_from_sizes(
"image", patches_per_image
),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -302,35 +331,6 @@ class DeepseekOCRMultiModalProcessor( ...@@ -302,35 +331,6 @@ class DeepseekOCRMultiModalProcessor(
) )
] ]
# TODO(Isotr0py): Check if we still need this workaround for
# deepseek-ocr processor.
# def _cached_apply_hf_processor(
# self,
# prompt: str | list[int],
# mm_data_items: MultiModalDataItems,
# hf_processor_mm_kwargs: Mapping[str, object],
# tokenization_kwargs: Mapping[str, object],
# mm_uuids: MultiModalUUIDDict | None = None,
# ) -> tuple[list[int], MultiModalKwargs, bool]:
# # The processor logic is different for len(images) <= 2 vs > 2
# # Since the processing cache assumes that the processor output is
# # invariant of how many images are passed per prompt, we only
# # perform caching for the most common case
# if mm_data_items.get_count("image", strict=False) > 2:
# # This code path corresponds to the cache being disabled
# return self._apply_hf_processor_main(
# prompt=prompt,
# mm_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# enable_hf_prompt_update=True,
# )
# return super()._cached_apply_hf_processor(
# prompt=prompt,
# mm_data_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# )
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor, DeepseekOCRMultiModalProcessor,
...@@ -338,6 +338,8 @@ class DeepseekOCRMultiModalProcessor( ...@@ -338,6 +338,8 @@ class DeepseekOCRMultiModalProcessor(
dummy_inputs=DeepseekOCRDummyInputsBuilder, dummy_inputs=DeepseekOCRDummyInputsBuilder,
) )
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# map prefix for language backbone # map prefix for language backbone
...@@ -389,6 +391,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -389,6 +391,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.vision_model = DeepCLIPVisionTransformer( self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config, config=clip_vision_config,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
) )
self.projector = MlpProjector(self.projector_config) self.projector = MlpProjector(self.projector_config)
...@@ -426,7 +429,9 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -426,7 +429,9 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
def _parse_and_validate_image_input(self, **kwargs: object): def _parse_and_validate_image_input(
self, **kwargs: object
) -> DeepseekOCRImagePixelInputs | None:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None) images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None) images_crop = kwargs.pop("images_crop", None)
...@@ -435,24 +440,17 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -435,24 +440,17 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)): base_size = self.vision_config.image_size
raise ValueError( return DeepseekOCRImagePixelInputs(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}" type="pixel_values",
) data=pixel_values,
images_crop=images_crop,
if not isinstance(images_spatial_crop, (torch.Tensor, list)): images_spatial_crop=images_spatial_crop,
raise ValueError( resolve_bindings={
"Incorrect type of image sizes. " "base_size": base_size,
f"Got type: {type(images_spatial_crop)}" },
)
if not isinstance(images_crop, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image crop. Got type: {type(images_crop)}"
) )
return [pixel_values, images_crop, images_spatial_crop]
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor: def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
...@@ -518,10 +516,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -518,10 +516,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> NestedTensors: ) -> NestedTensors:
images_in_this_batch = [] images_in_this_batch = []
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
images_crop = images_crop.split(patches_per_image.tolist())
for jdx in range(images_spatial_crop.size(0)): for jdx in range(images_spatial_crop.size(0)):
patches = images_crop[jdx][0].to(torch.bfloat16) patches = images_crop[jdx]
image_ori = pixel_values[jdx] image_ori = pixel_values[[jdx]]
crop_shape = images_spatial_crop[jdx][0] crop_shape = images_spatial_crop[jdx]
global_features = self._encode_global_features(image_ori) global_features = self._encode_global_features(image_ori)
local_features = self._encode_local_features(patches, crop_shape) local_features = self._encode_local_features(patches, crop_shape)
...@@ -540,10 +541,12 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -540,10 +541,12 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return images_in_this_batch return images_in_this_batch
def _process_image_input(self, image_input) -> torch.Tensor: def _process_image_input(
pixel_values = image_input[0].to(torch.bfloat16) self, image_input: DeepseekOCRImagePixelInputs
images_crop = image_input[1] ) -> torch.Tensor:
images_spatial_crop = image_input[2].to(dtype=torch.long) pixel_values = image_input.data
images_crop = image_input.images_crop
images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)
vision_features = self._pixel_values_to_embedding( vision_features = self._pixel_values_to_embedding(
pixel_values=pixel_values, pixel_values=pixel_values,
......
...@@ -411,20 +411,16 @@ class DeepseekOCRProcessor(ProcessorMixin): ...@@ -411,20 +411,16 @@ class DeepseekOCRProcessor(ProcessorMixin):
images_seq_mask = images_seq_mask[:-1] images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0: if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size)) pixel_values = torch.zeros((0, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long) images_spatial_crop = torch.zeros((0, 2), dtype=torch.long)
images_crop = torch.zeros( images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
else: else:
pixel_values = torch.stack(images_list, dim=0) pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list: if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0) images_crop = torch.stack(images_crop_list, dim=0)
else: else:
images_crop = torch.zeros( images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
input_ids = input_ids.unsqueeze(0) input_ids = input_ids.unsqueeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment