Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fdea8ec1
Unverified
Commit
fdea8ec1
authored
Dec 18, 2024
by
Alexander Matveev
Committed by
GitHub
Dec 18, 2024
Browse files
[V1] VLM - enable processor cache by default (#11305)
Signed-off-by:
Alexander Matveev
<
alexm@neuralmagic.com
>
parent
ca5f54a9
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
72 additions
and
48 deletions
+72
-48
examples/offline_inference_vision_language.py
examples/offline_inference_vision_language.py
+25
-25
vllm/config.py
vllm/config.py
+5
-6
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+5
-6
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+1
-1
vllm/v1/engine/mm_input_mapper.py
vllm/v1/engine/mm_input_mapper.py
+17
-3
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+2
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+17
-5
No files found.
examples/offline_inference_vision_language.py
View file @
fdea8ec1
...
@@ -28,7 +28,7 @@ def run_aria(question: str, modality: str):
...
@@ -28,7 +28,7 @@ def run_aria(question: str, modality: str):
tokenizer_mode
=
"slow"
,
tokenizer_mode
=
"slow"
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
prompt
=
(
f
"<|im_start|>user
\n
<fim_prefix><|img|><fim_suffix>
\n
{
question
}
"
prompt
=
(
f
"<|im_start|>user
\n
<fim_prefix><|img|><fim_suffix>
\n
{
question
}
"
"<|im_end|>
\n
<|im_start|>assistant
\n
"
)
"<|im_end|>
\n
<|im_start|>assistant
\n
"
)
...
@@ -45,7 +45,7 @@ def run_blip2(question: str, modality: str):
...
@@ -45,7 +45,7 @@ def run_blip2(question: str, modality: str):
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt
=
f
"Question:
{
question
}
Answer:"
prompt
=
f
"Question:
{
question
}
Answer:"
llm
=
LLM
(
model
=
"Salesforce/blip2-opt-2.7b"
,
llm
=
LLM
(
model
=
"Salesforce/blip2-opt-2.7b"
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -57,7 +57,7 @@ def run_chameleon(question: str, modality: str):
...
@@ -57,7 +57,7 @@ def run_chameleon(question: str, modality: str):
prompt
=
f
"
{
question
}
<image>"
prompt
=
f
"
{
question
}
<image>"
llm
=
LLM
(
model
=
"facebook/chameleon-7b"
,
llm
=
LLM
(
model
=
"facebook/chameleon-7b"
,
max_model_len
=
4096
,
max_model_len
=
4096
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -70,7 +70,7 @@ def run_fuyu(question: str, modality: str):
...
@@ -70,7 +70,7 @@ def run_fuyu(question: str, modality: str):
llm
=
LLM
(
model
=
"adept/fuyu-8b"
,
llm
=
LLM
(
model
=
"adept/fuyu-8b"
,
max_model_len
=
2048
,
max_model_len
=
2048
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -85,7 +85,7 @@ def run_glm4v(question: str, modality: str):
...
@@ -85,7 +85,7 @@ def run_glm4v(question: str, modality: str):
max_num_seqs
=
2
,
max_num_seqs
=
2
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
enforce_eager
=
True
,
enforce_eager
=
True
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
prompt
=
question
prompt
=
question
stop_token_ids
=
[
151329
,
151336
,
151338
]
stop_token_ids
=
[
151329
,
151336
,
151338
]
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -101,7 +101,7 @@ def run_h2ovl(question: str, modality: str):
...
@@ -101,7 +101,7 @@ def run_h2ovl(question: str, modality: str):
model
=
model_name
,
model
=
model_name
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
max_model_len
=
8192
,
max_model_len
=
8192
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
...
@@ -134,7 +134,7 @@ def run_idefics3(question: str, modality: str):
...
@@ -134,7 +134,7 @@ def run_idefics3(question: str, modality: str):
"longest_edge"
:
3
*
364
"longest_edge"
:
3
*
364
},
},
},
},
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
prompt
=
(
prompt
=
(
f
"<|begin_of_text|>User:<image>
{
question
}
<end_of_utterance>
\n
Assistant:"
f
"<|begin_of_text|>User:<image>
{
question
}
<end_of_utterance>
\n
Assistant:"
...
@@ -153,7 +153,7 @@ def run_internvl(question: str, modality: str):
...
@@ -153,7 +153,7 @@ def run_internvl(question: str, modality: str):
model
=
model_name
,
model
=
model_name
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
max_model_len
=
4096
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
...
@@ -180,7 +180,7 @@ def run_llava(question: str, modality: str):
...
@@ -180,7 +180,7 @@ def run_llava(question: str, modality: str):
llm
=
LLM
(
model
=
"llava-hf/llava-1.5-7b-hf"
,
llm
=
LLM
(
model
=
"llava-hf/llava-1.5-7b-hf"
,
max_model_len
=
4096
,
max_model_len
=
4096
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -192,7 +192,7 @@ def run_llava_next(question: str, modality: str):
...
@@ -192,7 +192,7 @@ def run_llava_next(question: str, modality: str):
prompt
=
f
"[INST] <image>
\n
{
question
}
[/INST]"
prompt
=
f
"[INST] <image>
\n
{
question
}
[/INST]"
llm
=
LLM
(
model
=
"llava-hf/llava-v1.6-mistral-7b-hf"
,
llm
=
LLM
(
model
=
"llava-hf/llava-v1.6-mistral-7b-hf"
,
max_model_len
=
8192
,
max_model_len
=
8192
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -205,7 +205,7 @@ def run_llava_next_video(question: str, modality: str):
...
@@ -205,7 +205,7 @@ def run_llava_next_video(question: str, modality: str):
prompt
=
f
"USER: <video>
\n
{
question
}
ASSISTANT:"
prompt
=
f
"USER: <video>
\n
{
question
}
ASSISTANT:"
llm
=
LLM
(
model
=
"llava-hf/LLaVA-NeXT-Video-7B-hf"
,
llm
=
LLM
(
model
=
"llava-hf/LLaVA-NeXT-Video-7B-hf"
,
max_model_len
=
8192
,
max_model_len
=
8192
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -223,7 +223,7 @@ def run_llava_onevision(question: str, modality: str):
...
@@ -223,7 +223,7 @@ def run_llava_onevision(question: str, modality: str):
llm
=
LLM
(
model
=
"llava-hf/llava-onevision-qwen2-7b-ov-hf"
,
llm
=
LLM
(
model
=
"llava-hf/llava-onevision-qwen2-7b-ov-hf"
,
max_model_len
=
16384
,
max_model_len
=
16384
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -239,7 +239,7 @@ def run_mantis(question: str, modality: str):
...
@@ -239,7 +239,7 @@ def run_mantis(question: str, modality: str):
model
=
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
model
=
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
max_model_len
=
4096
,
max_model_len
=
4096
,
hf_overrides
=
{
"architectures"
:
[
"MantisForConditionalGeneration"
]},
hf_overrides
=
{
"architectures"
:
[
"MantisForConditionalGeneration"
]},
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
stop_token_ids
=
[
128009
]
stop_token_ids
=
[
128009
]
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -266,7 +266,7 @@ def run_minicpmv(question: str, modality: str):
...
@@ -266,7 +266,7 @@ def run_minicpmv(question: str, modality: str):
max_model_len
=
4096
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
# 2.0
# 2.0
...
@@ -305,7 +305,7 @@ def run_mllama(question: str, modality: str):
...
@@ -305,7 +305,7 @@ def run_mllama(question: str, modality: str):
max_model_len
=
4096
,
max_model_len
=
4096
,
max_num_seqs
=
16
,
max_num_seqs
=
16
,
enforce_eager
=
True
,
enforce_eager
=
True
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
prompt
=
f
"<|image|><|begin_of_text|>
{
question
}
"
prompt
=
f
"<|image|><|begin_of_text|>
{
question
}
"
...
@@ -323,7 +323,7 @@ def run_molmo(question, modality):
...
@@ -323,7 +323,7 @@ def run_molmo(question, modality):
model
=
model_name
,
model
=
model_name
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
prompt
=
question
prompt
=
question
...
@@ -343,7 +343,7 @@ def run_nvlm_d(question: str, modality: str):
...
@@ -343,7 +343,7 @@ def run_nvlm_d(question: str, modality: str):
trust_remote_code
=
True
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
max_model_len
=
4096
,
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
...
@@ -363,7 +363,7 @@ def run_paligemma(question: str, modality: str):
...
@@ -363,7 +363,7 @@ def run_paligemma(question: str, modality: str):
# PaliGemma has special prompt format for VQA
# PaliGemma has special prompt format for VQA
prompt
=
"caption en"
prompt
=
"caption en"
llm
=
LLM
(
model
=
"google/paligemma-3b-mix-224"
,
llm
=
LLM
(
model
=
"google/paligemma-3b-mix-224"
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -375,7 +375,7 @@ def run_paligemma2(question: str, modality: str):
...
@@ -375,7 +375,7 @@ def run_paligemma2(question: str, modality: str):
# PaliGemma 2 has special prompt format for VQA
# PaliGemma 2 has special prompt format for VQA
prompt
=
"caption en"
prompt
=
"caption en"
llm
=
LLM
(
model
=
"google/paligemma2-3b-ft-docci-448"
,
llm
=
LLM
(
model
=
"google/paligemma2-3b-ft-docci-448"
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
)
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -405,7 +405,7 @@ def run_phi3v(question: str, modality: str):
...
@@ -405,7 +405,7 @@ def run_phi3v(question: str, modality: str):
max_num_seqs
=
2
,
max_num_seqs
=
2
,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs
=
{
"num_crops"
:
16
},
mm_processor_kwargs
=
{
"num_crops"
:
16
},
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -420,7 +420,7 @@ def run_pixtral_hf(question: str, modality: str):
...
@@ -420,7 +420,7 @@ def run_pixtral_hf(question: str, modality: str):
llm
=
LLM
(
llm
=
LLM
(
model
=
model_name
,
model
=
model_name
,
max_model_len
=
8192
,
max_model_len
=
8192
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
prompt
=
f
"<s>[INST]
{
question
}
\n
[IMG][/INST]"
prompt
=
f
"<s>[INST]
{
question
}
\n
[IMG][/INST]"
...
@@ -437,7 +437,7 @@ def run_qwen_vl(question: str, modality: str):
...
@@ -437,7 +437,7 @@ def run_qwen_vl(question: str, modality: str):
trust_remote_code
=
True
,
trust_remote_code
=
True
,
max_model_len
=
1024
,
max_model_len
=
1024
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
prompt
=
f
"
{
question
}
Picture 1: <img></img>
\n
"
prompt
=
f
"
{
question
}
Picture 1: <img></img>
\n
"
...
@@ -460,7 +460,7 @@ def run_qwen2_vl(question: str, modality: str):
...
@@ -460,7 +460,7 @@ def run_qwen2_vl(question: str, modality: str):
"min_pixels"
:
28
*
28
,
"min_pixels"
:
28
*
28
,
"max_pixels"
:
1280
*
28
*
28
,
"max_pixels"
:
1280
*
28
*
28
,
},
},
mm_cache
_preprocessor
=
args
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
args
.
disable_mm
_preprocessor
_cache
,
)
)
prompt
=
(
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
"
prompt
=
(
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
"
...
@@ -651,9 +651,9 @@ if __name__ == "__main__":
...
@@ -651,9 +651,9 @@ if __name__ == "__main__":
' (if enabled)'
)
' (if enabled)'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--
mm-cache
-preprocessor'
,
'--
disable-mm
-preprocessor
-cache
'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'If True,
en
able caching of multi-modal preprocessor/mapper.'
)
help
=
'If True,
dis
able
s
caching of multi-modal preprocessor/mapper.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--time-generate'
,
'--time-generate'
,
...
...
vllm/config.py
View file @
fdea8ec1
...
@@ -148,9 +148,8 @@ class ModelConfig:
...
@@ -148,9 +148,8 @@ class ModelConfig:
HuggingFace config.
HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
for multi-modal data, e.g., image processor.
mm_cache_preprocessor: If true, then enables caching of the multi-modal
disable_mm_preprocessor_cache: If true, then disables caching of the
preprocessor/mapper. Otherwise, the mapper executes each time, and
multi-modal preprocessor/mapper. (not recommended)
for better performance consider enabling frontend process.
override_neuron_config: Initialize non default neuron config or
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
this argument will be used to configure the neuron config that
...
@@ -216,7 +215,7 @@ class ModelConfig:
...
@@ -216,7 +215,7 @@ class ModelConfig:
config_format
:
ConfigFormat
=
ConfigFormat
.
AUTO
,
config_format
:
ConfigFormat
=
ConfigFormat
.
AUTO
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
mm_cache
_preprocessor
:
bool
=
False
,
disable_mm
_preprocessor
_cache
:
bool
=
False
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
override_pooler_config
:
Optional
[
"PoolerConfig"
]
=
None
,
override_pooler_config
:
Optional
[
"PoolerConfig"
]
=
None
,
logits_processor_pattern
:
Optional
[
str
]
=
None
)
->
None
:
logits_processor_pattern
:
Optional
[
str
]
=
None
)
->
None
:
...
@@ -286,7 +285,7 @@ class ModelConfig:
...
@@ -286,7 +285,7 @@ class ModelConfig:
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
use_async_output_proc
=
use_async_output_proc
self
.
use_async_output_proc
=
use_async_output_proc
self
.
mm_processor_kwargs
=
mm_processor_kwargs
self
.
mm_processor_kwargs
=
mm_processor_kwargs
self
.
mm_cache
_preprocessor
=
mm
_cache_preprocessor
self
.
disable_mm
_preprocessor_cache
=
disable_mm
_preprocessor
_cache
# Set enforce_eager to False if the value is unset.
# Set enforce_eager to False if the value is unset.
if
self
.
enforce_eager
is
None
:
if
self
.
enforce_eager
is
None
:
...
@@ -3155,7 +3154,7 @@ class VllmConfig:
...
@@ -3155,7 +3154,7 @@ class VllmConfig:
f
"enable_prefix_caching=
{
self
.
cache_config
.
enable_prefix_caching
}
, "
f
"enable_prefix_caching=
{
self
.
cache_config
.
enable_prefix_caching
}
, "
f
"chunked_prefill_enabled=
{
self
.
scheduler_config
.
chunked_prefill_enabled
}
, "
# noqa
f
"chunked_prefill_enabled=
{
self
.
scheduler_config
.
chunked_prefill_enabled
}
, "
# noqa
f
"use_async_output_proc=
{
self
.
model_config
.
use_async_output_proc
}
, "
f
"use_async_output_proc=
{
self
.
model_config
.
use_async_output_proc
}
, "
f
"
mm_cache
_preprocessor=
{
self
.
model_config
.
mm_cache
_preprocessor
!
r
}
, "
# noqa
f
"
disable_mm
_preprocessor
_cache
=
{
self
.
model_config
.
disable_mm
_preprocessor
_cache
!
r
}
, "
# noqa
f
"mm_processor_kwargs=
{
self
.
model_config
.
mm_processor_kwargs
}
, "
f
"mm_processor_kwargs=
{
self
.
model_config
.
mm_processor_kwargs
}
, "
f
"pooler_config=
{
self
.
model_config
.
pooler_config
!
r
}
, "
f
"pooler_config=
{
self
.
model_config
.
pooler_config
!
r
}
, "
f
"compilation_config=
{
self
.
compilation_config
!
r
}
"
)
f
"compilation_config=
{
self
.
compilation_config
!
r
}
"
)
...
...
vllm/engine/arg_utils.py
View file @
fdea8ec1
...
@@ -141,7 +141,7 @@ class EngineArgs:
...
@@ -141,7 +141,7 @@ class EngineArgs:
tokenizer_pool_extra_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
tokenizer_pool_extra_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
mm_cache
_preprocessor
:
bool
=
False
disable_mm
_preprocessor
_cache
:
bool
=
False
enable_lora
:
bool
=
False
enable_lora
:
bool
=
False
enable_lora_bias
:
bool
=
False
enable_lora_bias
:
bool
=
False
max_loras
:
int
=
1
max_loras
:
int
=
1
...
@@ -606,11 +606,10 @@ class EngineArgs:
...
@@ -606,11 +606,10 @@ class EngineArgs:
help
=
(
'Overrides for the multimodal input mapping/processing, '
help
=
(
'Overrides for the multimodal input mapping/processing, '
'e.g., image processor. For example: {"num_crops": 4}.'
))
'e.g., image processor. For example: {"num_crops": 4}.'
))
parser
.
add_argument
(
parser
.
add_argument
(
'--
mm-cache
-preprocessor'
,
'--
disable-mm
-preprocessor
-cache
'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'If true, then enables caching of the multi-modal '
help
=
'If true, then disables caching of the multi-modal '
'preprocessor/mapper. Otherwise, the mapper executes each time'
'preprocessor/mapper. (not recommended)'
)
', and for better performance consider enabling frontend process.'
)
# LoRA related configs
# LoRA related configs
parser
.
add_argument
(
'--enable-lora'
,
parser
.
add_argument
(
'--enable-lora'
,
...
@@ -983,7 +982,7 @@ class EngineArgs:
...
@@ -983,7 +982,7 @@ class EngineArgs:
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
config_format
=
self
.
config_format
,
config_format
=
self
.
config_format
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_cache
_preprocessor
=
self
.
mm_cache
_preprocessor
,
disable_mm
_preprocessor
_cache
=
self
.
disable_mm
_preprocessor
_cache
,
override_neuron_config
=
self
.
override_neuron_config
,
override_neuron_config
=
self
.
override_neuron_config
,
override_pooler_config
=
self
.
override_pooler_config
,
override_pooler_config
=
self
.
override_pooler_config
,
logits_processor_pattern
=
self
.
logits_processor_pattern
)
logits_processor_pattern
=
self
.
logits_processor_pattern
)
...
...
vllm/v1/core/kv_cache_utils.py
View file @
fdea8ec1
...
@@ -191,7 +191,7 @@ def generate_block_hash_extra_keys(
...
@@ -191,7 +191,7 @@ def generate_block_hash_extra_keys(
raise
ValueError
(
raise
ValueError
(
"The number of multi-modal positions and hashes must match. This "
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"is likely because you do not enable MM preprocessor hashing. "
"Please set
mm_cache
_preprocessor
=Tru
e."
)
"Please set
disable_mm
_preprocessor
_cache=Fals
e."
)
# Note that we assume mm_positions is sorted by offset.
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# We do not need to check all mm inputs if the start token index is out of
...
...
vllm/v1/engine/mm_input_mapper.py
View file @
fdea8ec1
...
@@ -43,7 +43,7 @@ class MMInputMapperClient:
...
@@ -43,7 +43,7 @@ class MMInputMapperClient:
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
# Init cache
# Init cache
self
.
use_cache
=
model_config
.
mm_cache
_preprocessor
self
.
use_cache
=
not
model_config
.
disable_mm
_preprocessor
_cache
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
# DEBUG: Set to None to disable
# DEBUG: Set to None to disable
...
@@ -119,7 +119,7 @@ class MMInputMapperClient:
...
@@ -119,7 +119,7 @@ class MMInputMapperClient:
class
MMInputMapperServer
:
class
MMInputMapperServer
:
def
__init__
(
self
,
model_config
):
def
__init__
(
self
,
model_config
):
self
.
use_cache
=
model_config
.
mm_cache
_preprocessor
self
.
use_cache
=
not
model_config
.
disable_mm
_preprocessor
_cache
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
def
process_inputs
(
def
process_inputs
(
...
@@ -151,12 +151,26 @@ class MMHasher:
...
@@ -151,12 +151,26 @@ class MMHasher:
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
hash
(
self
,
prompt
:
PromptType
)
->
Optional
[
List
[
str
]]:
def
hash_mm_data
(
self
,
mm_data
:
Optional
[
MultiModalDataDict
])
->
Optional
[
List
[
str
]]:
if
mm_data
is
None
:
return
None
image_inputs
=
mm_data
[
'image'
]
return
self
.
hash_images
(
image_inputs
)
def
hash_prompt
(
self
,
prompt
:
PromptType
)
->
Optional
[
List
[
str
]]:
if
"multi_modal_data"
not
in
prompt
:
if
"multi_modal_data"
not
in
prompt
:
return
None
return
None
mm_data
=
prompt
[
"multi_modal_data"
]
mm_data
=
prompt
[
"multi_modal_data"
]
image_inputs
=
mm_data
[
"image"
]
image_inputs
=
mm_data
[
"image"
]
return
self
.
hash_images
(
image_inputs
)
def
hash_images
(
self
,
image_inputs
)
->
Optional
[
List
[
str
]]:
if
not
isinstance
(
image_inputs
,
list
):
if
not
isinstance
(
image_inputs
,
list
):
image_inputs
=
[
image_inputs
]
image_inputs
=
[
image_inputs
]
assert
len
(
image_inputs
)
>
0
assert
len
(
image_inputs
)
>
0
...
...
vllm/v1/engine/processor.py
View file @
fdea8ec1
...
@@ -46,7 +46,7 @@ class Processor:
...
@@ -46,7 +46,7 @@ class Processor:
self
.
mm_input_mapper_client
=
MMInputMapperClient
(
model_config
)
self
.
mm_input_mapper_client
=
MMInputMapperClient
(
model_config
)
# Multi-modal hasher (for images)
# Multi-modal hasher (for images)
self
.
use_hash
=
model_config
.
mm_cache
_preprocessor
or
\
self
.
use_hash
=
(
not
model_config
.
disable_mm
_preprocessor
_cache
)
or
\
cache_config
.
enable_prefix_caching
cache_config
.
enable_prefix_caching
self
.
mm_hasher
=
MMHasher
()
self
.
mm_hasher
=
MMHasher
()
...
@@ -80,7 +80,7 @@ class Processor:
...
@@ -80,7 +80,7 @@ class Processor:
# Compute MM hashes (if enabled)
# Compute MM hashes (if enabled)
mm_hashes
=
None
mm_hashes
=
None
if
self
.
use_hash
:
if
self
.
use_hash
:
mm_hashes
=
self
.
mm_hasher
.
hash
(
prompt
)
mm_hashes
=
self
.
mm_hasher
.
hash
_prompt
(
prompt
)
# Process inputs.
# Process inputs.
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
fdea8ec1
...
@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
...
@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
FlashAttentionMetadata
)
FlashAttentionMetadata
)
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperClient
from
vllm.v1.engine.mm_input_mapper
import
MMHasher
,
MMInputMapperClient
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
@@ -79,8 +79,14 @@ class GPUModelRunner:
...
@@ -79,8 +79,14 @@ class GPUModelRunner:
# Multi-modal data support
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
input_registry
=
INPUT_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
# NOTE: mm_input_mapper is only used for memory profiling.
self
.
mm_input_mapper
=
MMInputMapperClient
(
self
.
model_config
)
# NOTE: mm_input_mapper_client and mm_hasher are only used for memory
# profiling.
self
.
mm_input_mapper_client
=
MMInputMapperClient
(
self
.
model_config
)
self
.
mm_hasher
=
MMHasher
()
self
.
use_hash
=
(
not
model_config
.
disable_mm_preprocessor_cache
)
or
\
cache_config
.
enable_prefix_caching
self
.
max_num_encoder_input_tokens
=
self
.
scheduler_config
.
max_num_encoder_input_tokens
# noqa: E501
self
.
max_num_encoder_input_tokens
=
self
.
scheduler_config
.
max_num_encoder_input_tokens
# noqa: E501
self
.
encoder_cache_size
=
self
.
scheduler_config
.
encoder_cache_size
self
.
encoder_cache_size
=
self
.
scheduler_config
.
encoder_cache_size
...
@@ -628,9 +634,15 @@ class GPUModelRunner:
...
@@ -628,9 +634,15 @@ class GPUModelRunner:
mm_registry
=
self
.
mm_registry
,
mm_registry
=
self
.
mm_registry
,
)
)
dummy_mm_data
=
dummy_request_data
.
multi_modal_data
dummy_mm_data
=
dummy_request_data
.
multi_modal_data
dummy_mm_kwargs
,
_
=
self
.
mm_input_mapper
.
process_inputs
(
# Compute MM hashes (if enabled)
mm_hashes
=
None
if
self
.
use_hash
:
mm_hashes
=
self
.
mm_hasher
.
hash_mm_data
(
dummy_mm_data
)
dummy_mm_kwargs
=
self
.
mm_input_mapper_client
.
process_inputs
(
mm_data
=
dummy_mm_data
,
mm_data
=
dummy_mm_data
,
mm_hashes
=
None
,
mm_hashes
=
mm_hashes
,
mm_processor_kwargs
=
None
,
mm_processor_kwargs
=
None
,
precomputed_mm_inputs
=
None
)
precomputed_mm_inputs
=
None
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment