Unverified commit fdea8ec1, authored by Alexander Matveev and committed by GitHub
Browse files

[V1] VLM - enable processor cache by default (#11305)


Signed-off-by: Alexander Matveev <alexm@neuralmagic.com>
parent ca5f54a9
...@@ -28,7 +28,7 @@ def run_aria(question: str, modality: str): ...@@ -28,7 +28,7 @@ def run_aria(question: str, modality: str):
tokenizer_mode="slow", tokenizer_mode="slow",
trust_remote_code=True, trust_remote_code=True,
dtype="bfloat16", dtype="bfloat16",
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}" prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
"<|im_end|>\n<|im_start|>assistant\n") "<|im_end|>\n<|im_start|>assistant\n")
...@@ -45,7 +45,7 @@ def run_blip2(question: str, modality: str): ...@@ -45,7 +45,7 @@ def run_blip2(question: str, modality: str):
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt = f"Question: {question} Answer:" prompt = f"Question: {question} Answer:"
llm = LLM(model="Salesforce/blip2-opt-2.7b", llm = LLM(model="Salesforce/blip2-opt-2.7b",
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -57,7 +57,7 @@ def run_chameleon(question: str, modality: str): ...@@ -57,7 +57,7 @@ def run_chameleon(question: str, modality: str):
prompt = f"{question}<image>" prompt = f"{question}<image>"
llm = LLM(model="facebook/chameleon-7b", llm = LLM(model="facebook/chameleon-7b",
max_model_len=4096, max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -70,7 +70,7 @@ def run_fuyu(question: str, modality: str): ...@@ -70,7 +70,7 @@ def run_fuyu(question: str, modality: str):
llm = LLM(model="adept/fuyu-8b", llm = LLM(model="adept/fuyu-8b",
max_model_len=2048, max_model_len=2048,
max_num_seqs=2, max_num_seqs=2,
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -85,7 +85,7 @@ def run_glm4v(question: str, modality: str): ...@@ -85,7 +85,7 @@ def run_glm4v(question: str, modality: str):
max_num_seqs=2, max_num_seqs=2,
trust_remote_code=True, trust_remote_code=True,
enforce_eager=True, enforce_eager=True,
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = question prompt = question
stop_token_ids = [151329, 151336, 151338] stop_token_ids = [151329, 151336, 151338]
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -101,7 +101,7 @@ def run_h2ovl(question: str, modality: str): ...@@ -101,7 +101,7 @@ def run_h2ovl(question: str, modality: str):
model=model_name, model=model_name,
trust_remote_code=True, trust_remote_code=True,
max_model_len=8192, max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name,
...@@ -134,7 +134,7 @@ def run_idefics3(question: str, modality: str): ...@@ -134,7 +134,7 @@ def run_idefics3(question: str, modality: str):
"longest_edge": 3 * 364 "longest_edge": 3 * 364
}, },
}, },
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
prompt = ( prompt = (
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:" f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
...@@ -153,7 +153,7 @@ def run_internvl(question: str, modality: str): ...@@ -153,7 +153,7 @@ def run_internvl(question: str, modality: str):
model=model_name, model=model_name,
trust_remote_code=True, trust_remote_code=True,
max_model_len=4096, max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name,
...@@ -180,7 +180,7 @@ def run_llava(question: str, modality: str): ...@@ -180,7 +180,7 @@ def run_llava(question: str, modality: str):
llm = LLM(model="llava-hf/llava-1.5-7b-hf", llm = LLM(model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096, max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -192,7 +192,7 @@ def run_llava_next(question: str, modality: str): ...@@ -192,7 +192,7 @@ def run_llava_next(question: str, modality: str):
prompt = f"[INST] <image>\n{question} [/INST]" prompt = f"[INST] <image>\n{question} [/INST]"
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
max_model_len=8192, max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -205,7 +205,7 @@ def run_llava_next_video(question: str, modality: str): ...@@ -205,7 +205,7 @@ def run_llava_next_video(question: str, modality: str):
prompt = f"USER: <video>\n{question} ASSISTANT:" prompt = f"USER: <video>\n{question} ASSISTANT:"
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192, max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -223,7 +223,7 @@ def run_llava_onevision(question: str, modality: str): ...@@ -223,7 +223,7 @@ def run_llava_onevision(question: str, modality: str):
llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf", llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=16384, max_model_len=16384,
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -239,7 +239,7 @@ def run_mantis(question: str, modality: str): ...@@ -239,7 +239,7 @@ def run_mantis(question: str, modality: str):
model="TIGER-Lab/Mantis-8B-siglip-llama3", model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096, max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
stop_token_ids = [128009] stop_token_ids = [128009]
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -266,7 +266,7 @@ def run_minicpmv(question: str, modality: str): ...@@ -266,7 +266,7 @@ def run_minicpmv(question: str, modality: str):
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
trust_remote_code=True, trust_remote_code=True,
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
# NOTE The stop_token_ids are different for various versions of MiniCPM-V # NOTE The stop_token_ids are different for various versions of MiniCPM-V
# 2.0 # 2.0
...@@ -305,7 +305,7 @@ def run_mllama(question: str, modality: str): ...@@ -305,7 +305,7 @@ def run_mllama(question: str, modality: str):
max_model_len=4096, max_model_len=4096,
max_num_seqs=16, max_num_seqs=16,
enforce_eager=True, enforce_eager=True,
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
prompt = f"<|image|><|begin_of_text|>{question}" prompt = f"<|image|><|begin_of_text|>{question}"
...@@ -323,7 +323,7 @@ def run_molmo(question, modality): ...@@ -323,7 +323,7 @@ def run_molmo(question, modality):
model=model_name, model=model_name,
trust_remote_code=True, trust_remote_code=True,
dtype="bfloat16", dtype="bfloat16",
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
prompt = question prompt = question
...@@ -343,7 +343,7 @@ def run_nvlm_d(question: str, modality: str): ...@@ -343,7 +343,7 @@ def run_nvlm_d(question: str, modality: str):
trust_remote_code=True, trust_remote_code=True,
max_model_len=4096, max_model_len=4096,
tensor_parallel_size=4, tensor_parallel_size=4,
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name,
...@@ -363,7 +363,7 @@ def run_paligemma(question: str, modality: str): ...@@ -363,7 +363,7 @@ def run_paligemma(question: str, modality: str):
# PaliGemma has special prompt format for VQA # PaliGemma has special prompt format for VQA
prompt = "caption en" prompt = "caption en"
llm = LLM(model="google/paligemma-3b-mix-224", llm = LLM(model="google/paligemma-3b-mix-224",
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -375,7 +375,7 @@ def run_paligemma2(question: str, modality: str): ...@@ -375,7 +375,7 @@ def run_paligemma2(question: str, modality: str):
# PaliGemma 2 has special prompt format for VQA # PaliGemma 2 has special prompt format for VQA
prompt = "caption en" prompt = "caption en"
llm = LLM(model="google/paligemma2-3b-ft-docci-448", llm = LLM(model="google/paligemma2-3b-ft-docci-448",
mm_cache_preprocessor=args.mm_cache_preprocessor) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -405,7 +405,7 @@ def run_phi3v(question: str, modality: str): ...@@ -405,7 +405,7 @@ def run_phi3v(question: str, modality: str):
max_num_seqs=2, max_num_seqs=2,
# Note - mm_processor_kwargs can also be passed to generate/chat calls # Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"num_crops": 16}, mm_processor_kwargs={"num_crops": 16},
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
...@@ -420,7 +420,7 @@ def run_pixtral_hf(question: str, modality: str): ...@@ -420,7 +420,7 @@ def run_pixtral_hf(question: str, modality: str):
llm = LLM( llm = LLM(
model=model_name, model=model_name,
max_model_len=8192, max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
prompt = f"<s>[INST]{question}\n[IMG][/INST]" prompt = f"<s>[INST]{question}\n[IMG][/INST]"
...@@ -437,7 +437,7 @@ def run_qwen_vl(question: str, modality: str): ...@@ -437,7 +437,7 @@ def run_qwen_vl(question: str, modality: str):
trust_remote_code=True, trust_remote_code=True,
max_model_len=1024, max_model_len=1024,
max_num_seqs=2, max_num_seqs=2,
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
prompt = f"{question}Picture 1: <img></img>\n" prompt = f"{question}Picture 1: <img></img>\n"
...@@ -460,7 +460,7 @@ def run_qwen2_vl(question: str, modality: str): ...@@ -460,7 +460,7 @@ def run_qwen2_vl(question: str, modality: str):
"min_pixels": 28 * 28, "min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28, "max_pixels": 1280 * 28 * 28,
}, },
mm_cache_preprocessor=args.mm_cache_preprocessor, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
...@@ -651,9 +651,9 @@ if __name__ == "__main__": ...@@ -651,9 +651,9 @@ if __name__ == "__main__":
' (if enabled)') ' (if enabled)')
parser.add_argument( parser.add_argument(
'--mm-cache-preprocessor', '--disable-mm-preprocessor-cache',
action='store_true', action='store_true',
help='If True, enable caching of multi-modal preprocessor/mapper.') help='If True, disables caching of multi-modal preprocessor/mapper.')
parser.add_argument( parser.add_argument(
'--time-generate', '--time-generate',
......
...@@ -148,9 +148,8 @@ class ModelConfig: ...@@ -148,9 +148,8 @@ class ModelConfig:
HuggingFace config. HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor. for multi-modal data, e.g., image processor.
mm_cache_preprocessor: If true, then enables caching of the multi-modal disable_mm_preprocessor_cache: If true, then disables caching of the
preprocessor/mapper. Otherwise, the mapper executes each time, and multi-modal preprocessor/mapper. (not recommended)
for better performance consider enabling frontend process.
override_neuron_config: Initialize non default neuron config or override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices, override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that this argument will be used to configure the neuron config that
...@@ -216,7 +215,7 @@ class ModelConfig: ...@@ -216,7 +215,7 @@ class ModelConfig:
config_format: ConfigFormat = ConfigFormat.AUTO, config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None, hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_cache_preprocessor: bool = False, disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None, override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None, override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None) -> None: logits_processor_pattern: Optional[str] = None) -> None:
...@@ -286,7 +285,7 @@ class ModelConfig: ...@@ -286,7 +285,7 @@ class ModelConfig:
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs self.mm_processor_kwargs = mm_processor_kwargs
self.mm_cache_preprocessor = mm_cache_preprocessor self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
# Set enforce_eager to False if the value is unset. # Set enforce_eager to False if the value is unset.
if self.enforce_eager is None: if self.enforce_eager is None:
...@@ -3155,7 +3154,7 @@ class VllmConfig: ...@@ -3155,7 +3154,7 @@ class VllmConfig:
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"use_async_output_proc={self.model_config.use_async_output_proc}, " f"use_async_output_proc={self.model_config.use_async_output_proc}, "
f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, " # noqa f"disable_mm_preprocessor_cache={self.model_config.disable_mm_preprocessor_cache!r}, " # noqa
f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, " f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
f"pooler_config={self.model_config.pooler_config!r}, " f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}") f"compilation_config={self.compilation_config!r}")
......
...@@ -141,7 +141,7 @@ class EngineArgs: ...@@ -141,7 +141,7 @@ class EngineArgs:
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None
mm_cache_preprocessor: bool = False disable_mm_preprocessor_cache: bool = False
enable_lora: bool = False enable_lora: bool = False
enable_lora_bias: bool = False enable_lora_bias: bool = False
max_loras: int = 1 max_loras: int = 1
...@@ -606,11 +606,10 @@ class EngineArgs: ...@@ -606,11 +606,10 @@ class EngineArgs:
help=('Overrides for the multimodal input mapping/processing, ' help=('Overrides for the multimodal input mapping/processing, '
'e.g., image processor. For example: {"num_crops": 4}.')) 'e.g., image processor. For example: {"num_crops": 4}.'))
parser.add_argument( parser.add_argument(
'--mm-cache-preprocessor', '--disable-mm-preprocessor-cache',
action='store_true', action='store_true',
help='If true, then enables caching of the multi-modal ' help='If true, then disables caching of the multi-modal '
'preprocessor/mapper. Otherwise, the mapper executes each time' 'preprocessor/mapper. (not recommended)')
', and for better performance consider enabling frontend process.')
# LoRA related configs # LoRA related configs
parser.add_argument('--enable-lora', parser.add_argument('--enable-lora',
...@@ -983,7 +982,7 @@ class EngineArgs: ...@@ -983,7 +982,7 @@ class EngineArgs:
use_async_output_proc=not self.disable_async_output_proc, use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format, config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
mm_cache_preprocessor=self.mm_cache_preprocessor, disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
override_neuron_config=self.override_neuron_config, override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config, override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern) logits_processor_pattern=self.logits_processor_pattern)
......
...@@ -191,7 +191,7 @@ def generate_block_hash_extra_keys( ...@@ -191,7 +191,7 @@ def generate_block_hash_extra_keys(
raise ValueError( raise ValueError(
"The number of multi-modal positions and hashes must match. This " "The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. " "is likely because you do not enable MM preprocessor hashing. "
"Please set mm_cache_preprocessor=True.") "Please set disable_mm_preprocessor_cache=False.")
# Note that we assume mm_positions is sorted by offset. # Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of # We do not need to check all mm inputs if the start token index is out of
......
...@@ -43,7 +43,7 @@ class MMInputMapperClient: ...@@ -43,7 +43,7 @@ class MMInputMapperClient:
self.mm_registry.init_mm_limits_per_prompt(model_config) self.mm_registry.init_mm_limits_per_prompt(model_config)
# Init cache # Init cache
self.use_cache = model_config.mm_cache_preprocessor self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
# DEBUG: Set to None to disable # DEBUG: Set to None to disable
...@@ -119,7 +119,7 @@ class MMInputMapperClient: ...@@ -119,7 +119,7 @@ class MMInputMapperClient:
class MMInputMapperServer: class MMInputMapperServer:
def __init__(self, model_config): def __init__(self, model_config):
self.use_cache = model_config.mm_cache_preprocessor self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
def process_inputs( def process_inputs(
...@@ -151,12 +151,26 @@ class MMHasher: ...@@ -151,12 +151,26 @@ class MMHasher:
def __init__(self): def __init__(self):
pass pass
def hash(self, prompt: PromptType) -> Optional[List[str]]: def hash_mm_data(
self,
mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
if mm_data is None:
return None
image_inputs = mm_data['image']
return self.hash_images(image_inputs)
def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]:
if "multi_modal_data" not in prompt: if "multi_modal_data" not in prompt:
return None return None
mm_data = prompt["multi_modal_data"] mm_data = prompt["multi_modal_data"]
image_inputs = mm_data["image"] image_inputs = mm_data["image"]
return self.hash_images(image_inputs)
def hash_images(self, image_inputs) -> Optional[List[str]]:
if not isinstance(image_inputs, list): if not isinstance(image_inputs, list):
image_inputs = [image_inputs] image_inputs = [image_inputs]
assert len(image_inputs) > 0 assert len(image_inputs) > 0
......
...@@ -46,7 +46,7 @@ class Processor: ...@@ -46,7 +46,7 @@ class Processor:
self.mm_input_mapper_client = MMInputMapperClient(model_config) self.mm_input_mapper_client = MMInputMapperClient(model_config)
# Multi-modal hasher (for images) # Multi-modal hasher (for images)
self.use_hash = model_config.mm_cache_preprocessor or \ self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching cache_config.enable_prefix_caching
self.mm_hasher = MMHasher() self.mm_hasher = MMHasher()
...@@ -80,7 +80,7 @@ class Processor: ...@@ -80,7 +80,7 @@ class Processor:
# Compute MM hashes (if enabled) # Compute MM hashes (if enabled)
mm_hashes = None mm_hashes = None
if self.use_hash: if self.use_hash:
mm_hashes = self.mm_hasher.hash(prompt) mm_hashes = self.mm_hasher.hash_prompt(prompt)
# Process inputs. # Process inputs.
preprocessed_inputs = self.input_preprocessor.preprocess( preprocessed_inputs = self.input_preprocessor.preprocess(
......
...@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, ...@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available) LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata) FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -79,8 +79,14 @@ class GPUModelRunner: ...@@ -79,8 +79,14 @@ class GPUModelRunner:
# Multi-modal data support # Multi-modal data support
self.input_registry = INPUT_REGISTRY self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: mm_input_mapper is only used for memory profiling.
self.mm_input_mapper = MMInputMapperClient(self.model_config) # NOTE: mm_input_mapper_client and mm_hasher are only used for memory
# profiling.
self.mm_input_mapper_client = MMInputMapperClient(self.model_config)
self.mm_hasher = MMHasher()
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size self.encoder_cache_size = self.scheduler_config.encoder_cache_size
...@@ -628,9 +634,15 @@ class GPUModelRunner: ...@@ -628,9 +634,15 @@ class GPUModelRunner:
mm_registry=self.mm_registry, mm_registry=self.mm_registry,
) )
dummy_mm_data = dummy_request_data.multi_modal_data dummy_mm_data = dummy_request_data.multi_modal_data
dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
mm_hashes = self.mm_hasher.hash_mm_data(dummy_mm_data)
dummy_mm_kwargs = self.mm_input_mapper_client.process_inputs(
mm_data=dummy_mm_data, mm_data=dummy_mm_data,
mm_hashes=None, mm_hashes=mm_hashes,
mm_processor_kwargs=None, mm_processor_kwargs=None,
precomputed_mm_inputs=None) precomputed_mm_inputs=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment