Unverified Commit 6ff51862 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix deepseek-vl2 inference with more than 2 images (#13818)

parent fa820741
...@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, ProcessingCache,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
...@@ -138,18 +139,24 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): ...@@ -138,18 +139,24 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_num_image_tokens(self, *, image_width: int, def get_num_image_tokens(self,
image_height: int) -> int: *,
image_width: int,
image_height: int,
cropping: bool = True) -> int:
hf_processor = self.get_hf_processor() hf_processor = self.get_hf_processor()
image_size = hf_processor.image_size image_size = hf_processor.image_size
patch_size = hf_processor.patch_size patch_size = hf_processor.patch_size
downsample_ratio = hf_processor.downsample_ratio downsample_ratio = hf_processor.downsample_ratio
if cropping:
best_width, best_height = hf_processor.select_best_resolution( best_width, best_height = hf_processor.select_best_resolution(
(image_width, image_height)) (image_width, image_height))
num_width_tiles, num_height_tiles = (best_width // image_size, num_width_tiles, num_height_tiles = (best_width // image_size,
best_height // image_size) best_height // image_size)
else:
num_width_tiles = num_height_tiles = 1
h = w = math.ceil((image_size // patch_size) / downsample_ratio) h = w = math.ceil((image_size // patch_size) / downsample_ratio)
global_views_tokens = h * (w + 1) global_views_tokens = h * (w + 1)
...@@ -169,10 +176,12 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): ...@@ -169,10 +176,12 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> Mapping[str, int]: ) -> Mapping[str, int]:
num_images = mm_counts.get("image", 0)
max_image_size = self.get_image_size_with_most_features() max_image_size = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens( max_image_tokens = self.get_num_image_tokens(
image_height=max_image_size.height, image_height=max_image_size.height,
image_width=max_image_size.width) image_width=max_image_size.width,
cropping=num_images <= 2)
return {"image": max_image_tokens} return {"image": max_image_tokens}
...@@ -207,6 +216,30 @@ class DeepseekVL2DummyInputsBuilder( ...@@ -207,6 +216,30 @@ class DeepseekVL2DummyInputsBuilder(
class DeepseekVL2MultiModalProcessor( class DeepseekVL2MultiModalProcessor(
BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
def __init__(
self,
info: DeepseekVL2ProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(
info,
dummy_inputs,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
if self.cache is not None and mm_limit["image"] > 2:
# The processor output depends on the number of images passed,
# making it incompatible with processing cache which is supposed
# to be invariant of how many images are passed per prompt
self.cache = None
logger.warning_once(
f"{type(self).__name__} does not support processing cache with "
"image limit larger than 2.")
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
...@@ -271,6 +304,7 @@ class DeepseekVL2MultiModalProcessor( ...@@ -271,6 +304,7 @@ class DeepseekVL2MultiModalProcessor(
num_image_tokens = self.info.get_num_image_tokens( num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
cropping=len(images) <= 2,
) )
return [image_token_id] * num_image_tokens return [image_token_id] * num_image_tokens
......
...@@ -477,13 +477,15 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] ...@@ -477,13 +477,15 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
enable_sanity_checks=enable_sanity_checks, enable_sanity_checks=enable_sanity_checks,
) )
if self.cache is not None: mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
if self.cache is not None and mm_limit["image"] >= 2:
# The processor output depends on the number of images passed, # The processor output depends on the number of images passed,
# making it incompatible with processing cache which is supposed # making it incompatible with processing cache which is supposed
# to be invariant of how many images are passed per prompt # to be invariant of how many images are passed per prompt
self.cache = None self.cache = None
logger.warning_once( logger.warning_once(
f"{type(self).__name__} does not support processing cache.") f"{type(self).__name__} does not support processing cache with "
"multi-image support enabled.")
def _get_prompt_replacements( def _get_prompt_replacements(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment