Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d9fc8cd9
Unverified
Commit
d9fc8cd9
authored
Apr 12, 2025
by
Cyrus Leung
Committed by
GitHub
Apr 12, 2025
Browse files
[V1] Enable multi-input by default (#15799)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
f069f3ea
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
213 additions
and
104 deletions
+213
-104
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+3
-1
docs/source/serving/offline_inference.md
docs/source/serving/offline_inference.md
+24
-0
examples/offline_inference/audio_language.py
examples/offline_inference/audio_language.py
+5
-0
examples/offline_inference/encoder_decoder_multimodal.py
examples/offline_inference/encoder_decoder_multimodal.py
+5
-0
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+45
-34
examples/offline_inference/vision_language_embedding.py
examples/offline_inference/vision_language_embedding.py
+7
-0
examples/offline_inference/vision_language_multi_image.py
examples/offline_inference/vision_language_multi_image.py
+5
-0
tests/entrypoints/openai/test_audio.py
tests/entrypoints/openai/test_audio.py
+34
-27
tests/models/decoder_only/vision_language/vlm_utils/core.py
tests/models/decoder_only/vision_language/vlm_utils/core.py
+4
-0
tests/models/test_oot_registration.py
tests/models/test_oot_registration.py
+1
-0
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+5
-2
vllm/config.py
vllm/config.py
+9
-3
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+3
-3
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+24
-5
vllm/model_executor/models/minicpmo.py
vllm/model_executor/models/minicpmo.py
+1
-1
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+2
-2
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+2
-2
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+29
-5
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+1
-17
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+4
-2
No files found.
docs/source/models/supported_models.md
View file @
d9fc8cd9
...
...
@@ -759,7 +759,7 @@ On the other hand, modalities separated by `/` are mutually exclusive.
See
[
this page
](
#multimodal-inputs
)
on how to pass multi-modal inputs to the model.
:::{important}
To enable multiple multi-modal items per text prompt, you have to set
`limit_mm_per_prompt`
(offline inference)
**
To enable multiple multi-modal items per text prompt
in vLLM V0**
, you have to set
`limit_mm_per_prompt`
(offline inference)
or
`--limit-mm-per-prompt`
(online serving). For example, to enable passing up to 4 images per text prompt:
Offline inference:
...
...
@@ -777,6 +777,8 @@ Online serving:
vllm serve Qwen/Qwen2-VL-7B-Instruct
--limit-mm-per-prompt
image
=
4
```
**This is no longer required if you are using vLLM V1.**
:::
:::{note}
...
...
docs/source/serving/offline_inference.md
View file @
d9fc8cd9
...
...
@@ -110,6 +110,30 @@ If you run out of CPU RAM, try the following options:
-
(Multi-modal models only) you can set the size of multi-modal input cache using
`VLLM_MM_INPUT_CACHE_GIB`
environment variable (default 4 GiB).
-
(CPU backend only) you can set the size of KV cache using
`VLLM_CPU_KVCACHE_SPACE`
environment variable (default 4 GiB).
#### Disable unused modalities
You can disable unused modalities (except for text) by setting its limit to zero.
For example, if your application only accepts image input, there is no need to allocate any memory for videos.
```
python
from
vllm
import
LLM
# Accept images but not videos
llm
=
LLM
(
model
=
"Qwen/Qwen2.5-VL-3B-Instruct"
,
limit_mm_per_prompt
=
{
"video"
:
0
})
```
You can even run a multi-modal model for text-only inference:
```
python
from
vllm
import
LLM
# Don't accept images. Just text.
llm
=
LLM
(
model
=
"google/gemma-3-27b-it"
,
limit_mm_per_prompt
=
{
"image"
:
0
})
```
### Performance optimization and tuning
You can potentially improve the performance of vLLM by finetuning various options.
...
...
examples/offline_inference/audio_language.py
View file @
d9fc8cd9
...
...
@@ -196,6 +196,11 @@ def main(args):
req_data
=
model_example_map
[
model
](
question_per_audio_count
[
audio_count
],
audio_count
)
# Disable other modalities to save memory
default_limits
=
{
"image"
:
0
,
"video"
:
0
,
"audio"
:
0
}
req_data
.
engine_args
.
limit_mm_per_prompt
=
default_limits
|
dict
(
req_data
.
engine_args
.
limit_mm_per_prompt
or
{})
engine_args
=
asdict
(
req_data
.
engine_args
)
|
{
"seed"
:
args
.
seed
}
llm
=
LLM
(
**
engine_args
)
...
...
examples/offline_inference/encoder_decoder_multimodal.py
View file @
d9fc8cd9
...
...
@@ -133,6 +133,11 @@ def main(args):
req_data
=
model_example_map
[
model
]()
# Disable other modalities to save memory
default_limits
=
{
"image"
:
0
,
"video"
:
0
,
"audio"
:
0
}
req_data
.
engine_args
.
limit_mm_per_prompt
=
default_limits
|
dict
(
req_data
.
engine_args
.
limit_mm_per_prompt
or
{})
engine_args
=
asdict
(
req_data
.
engine_args
)
|
{
"seed"
:
args
.
seed
}
llm
=
LLM
(
**
engine_args
)
...
...
examples/offline_inference/vision_language.py
View file @
d9fc8cd9
...
...
@@ -45,7 +45,7 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
max_model_len
=
4096
,
max_num_seqs
=
2
,
dtype
=
"bfloat16"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[(
f
"<|im_start|>user
\n
<fim_prefix><|img|><fim_suffix>
{
question
}
"
...
...
@@ -71,7 +71,7 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
max_model_len
=
2048
,
max_num_seqs
=
2
,
mm_processor_kwargs
=
{
"crop_to_patches"
:
True
},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[
f
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|><image>
{
question
}
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
...
...
@@ -92,7 +92,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
prompts
=
[
f
"Question:
{
question
}
Answer:"
for
question
in
questions
]
engine_args
=
EngineArgs
(
model
=
"Salesforce/blip2-opt-6.7b"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
return
ModelRequestData
(
...
...
@@ -110,7 +110,7 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
model
=
"facebook/chameleon-7b"
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
return
ModelRequestData
(
...
...
@@ -129,8 +129,8 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
model
=
model_name
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]},
limit_mm_per_prompt
=
{
"image"
:
1
},
)
prompts
=
[
...
...
@@ -155,7 +155,7 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs
=
2
,
trust_remote_code
=
True
,
dtype
=
"bfloat16"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[
"<MORE_DETAILED_CAPTION>"
for
_
in
questions
]
...
...
@@ -175,7 +175,7 @@ def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
model
=
"adept/fuyu-8b"
,
max_model_len
=
2048
,
max_num_seqs
=
2
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
return
ModelRequestData
(
...
...
@@ -194,7 +194,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
max_model_len
=
2048
,
max_num_seqs
=
2
,
mm_processor_kwargs
=
{
"do_pan_and_scan"
:
True
},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[(
"<bos><start_of_turn>user
\n
"
...
...
@@ -219,7 +219,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
trust_remote_code
=
True
,
enforce_eager
=
True
,
hf_overrides
=
{
"architectures"
:
[
"GLM4VForCausalLM"
]},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[
...
...
@@ -246,7 +246,7 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
model
=
model_name
,
trust_remote_code
=
True
,
max_model_len
=
8192
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
...
...
@@ -287,7 +287,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
"longest_edge"
:
3
*
364
},
},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[(
f
"<|begin_of_text|>User:<image>
{
question
}
<end_of_utterance>
\n
Assistant:"
...
...
@@ -314,7 +314,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
"longest_edge"
:
384
},
},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[
(
f
"<|im_start|>User:<image>
{
question
}
<end_of_utterance>
\n
Assistant:"
)
...
...
@@ -337,7 +337,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
model
=
model_name
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
...
...
@@ -375,7 +375,7 @@ def run_llava(questions: list[str], modality: str) -> ModelRequestData:
engine_args
=
EngineArgs
(
model
=
"llava-hf/llava-1.5-7b-hf"
,
max_model_len
=
4096
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
return
ModelRequestData
(
...
...
@@ -392,7 +392,7 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
engine_args
=
EngineArgs
(
model
=
"llava-hf/llava-v1.6-mistral-7b-hf"
,
max_model_len
=
8192
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
return
ModelRequestData
(
...
...
@@ -414,7 +414,7 @@ def run_llava_next_video(questions: list[str],
model
=
"llava-hf/LLaVA-NeXT-Video-7B-hf"
,
max_model_len
=
8192
,
max_num_seqs
=
2
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
return
ModelRequestData
(
...
...
@@ -442,7 +442,7 @@ def run_llava_onevision(questions: list[str],
engine_args
=
EngineArgs
(
model
=
"llava-hf/llava-onevision-qwen2-7b-ov-hf"
,
max_model_len
=
16384
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
return
ModelRequestData
(
...
...
@@ -465,7 +465,7 @@ def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
model
=
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
max_model_len
=
4096
,
hf_overrides
=
{
"architectures"
:
[
"MantisForConditionalGeneration"
]},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
stop_token_ids
=
[
128009
]
...
...
@@ -506,7 +506,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
max_model_len
=
4096
,
max_num_seqs
=
2
,
trust_remote_code
=
True
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
# 2.0
...
...
@@ -561,7 +561,7 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
max_model_len
=
8192
,
max_num_seqs
=
2
,
tensor_parallel_size
=
2
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[
f
"<s>[INST]
{
question
}
\n
[IMG][/INST]"
for
question
in
questions
]
...
...
@@ -587,7 +587,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
model
=
model_name
,
max_model_len
=
8192
,
max_num_seqs
=
2
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
...
...
@@ -611,7 +611,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
)
def
run_llama4
(
questions
:
list
[
str
],
modality
:
str
):
def
run_llama4
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
assert
modality
==
"image"
model_name
=
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
...
...
@@ -621,8 +621,8 @@ def run_llama4(questions: list[str], modality: str):
max_model_len
=
8192
,
max_num_seqs
=
4
,
tensor_parallel_size
=
8
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
gpu_memory_utilization
=
0.4
,
limit_mm_per_prompt
=
{
"image"
:
1
},
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
...
...
@@ -657,7 +657,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
model
=
model_name
,
trust_remote_code
=
True
,
dtype
=
"bfloat16"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[
...
...
@@ -683,7 +683,7 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
trust_remote_code
=
True
,
max_model_len
=
4096
,
tensor_parallel_size
=
4
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
...
...
@@ -710,7 +710,8 @@ def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
prompts
=
[
"caption en"
for
_
in
questions
]
engine_args
=
EngineArgs
(
model
=
"google/paligemma-3b-mix-224"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
limit_mm_per_prompt
=
{
"image"
:
1
},
)
return
ModelRequestData
(
engine_args
=
engine_args
,
...
...
@@ -726,7 +727,8 @@ def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
prompts
=
[
"caption en"
for
_
in
questions
]
engine_args
=
EngineArgs
(
model
=
"google/paligemma2-3b-ft-docci-448"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
limit_mm_per_prompt
=
{
"image"
:
1
},
)
return
ModelRequestData
(
engine_args
=
engine_args
,
...
...
@@ -762,7 +764,7 @@ def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs
=
2
,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs
=
{
"num_crops"
:
16
},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
return
ModelRequestData
(
...
...
@@ -793,6 +795,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs
=
2
,
enable_lora
=
True
,
max_lora_rank
=
320
,
limit_mm_per_prompt
=
{
"image"
:
1
},
)
return
ModelRequestData
(
...
...
@@ -813,7 +816,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
model
=
model_name
,
max_model_len
=
6144
,
max_num_seqs
=
2
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[
f
"<s>[INST]
{
question
}
\n
[IMG][/INST]"
for
question
in
questions
]
...
...
@@ -834,7 +837,7 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
max_model_len
=
1024
,
max_num_seqs
=
2
,
hf_overrides
=
{
"architectures"
:
[
"QwenVLForConditionalGeneration"
]},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
prompts
=
[
f
"
{
question
}
Picture 1: <img></img>
\n
"
for
question
in
questions
]
...
...
@@ -859,7 +862,7 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
"min_pixels"
:
28
*
28
,
"max_pixels"
:
1280
*
28
*
28
,
},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
if
modality
==
"image"
:
...
...
@@ -894,7 +897,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
"max_pixels"
:
1280
*
28
*
28
,
"fps"
:
1
,
},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
if
modality
==
"image"
:
...
...
@@ -925,7 +928,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
model
=
model_name
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
limit_mm_per_prompt
=
{
"image"
:
1
}
,
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
...
...
@@ -1082,7 +1085,15 @@ def main(args):
req_data
=
model_example_map
[
model
](
questions
,
modality
)
engine_args
=
asdict
(
req_data
.
engine_args
)
|
{
"seed"
:
args
.
seed
}
# Disable other modalities to save memory
default_limits
=
{
"image"
:
0
,
"video"
:
0
,
"audio"
:
0
}
req_data
.
engine_args
.
limit_mm_per_prompt
=
default_limits
|
dict
(
req_data
.
engine_args
.
limit_mm_per_prompt
or
{})
engine_args
=
asdict
(
req_data
.
engine_args
)
|
{
"seed"
:
args
.
seed
,
"disable_mm_preprocessor_cache"
:
args
.
disable_mm_preprocessor_cache
,
}
llm
=
LLM
(
**
engine_args
)
# To maintain code compatibility in this script, we add LoRA here.
...
...
examples/offline_inference/vision_language_embedding.py
View file @
d9fc8cd9
...
...
@@ -63,6 +63,7 @@ def run_e5_v(query: Query) -> ModelRequestData:
model
=
"royokong/e5-v"
,
task
=
"embed"
,
max_model_len
=
4096
,
limit_mm_per_prompt
=
{
"image"
:
1
},
)
return
ModelRequestData
(
...
...
@@ -93,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
task
=
"embed"
,
trust_remote_code
=
True
,
mm_processor_kwargs
=
{
"num_crops"
:
4
},
limit_mm_per_prompt
=
{
"image"
:
1
},
)
return
ModelRequestData
(
...
...
@@ -131,6 +133,11 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
query
=
get_query
(
modality
)
req_data
=
model_example_map
[
model
](
query
)
# Disable other modalities to save memory
default_limits
=
{
"image"
:
0
,
"video"
:
0
,
"audio"
:
0
}
req_data
.
engine_args
.
limit_mm_per_prompt
=
default_limits
|
dict
(
req_data
.
engine_args
.
limit_mm_per_prompt
or
{})
engine_args
=
asdict
(
req_data
.
engine_args
)
|
{
"seed"
:
seed
}
llm
=
LLM
(
**
engine_args
)
...
...
examples/offline_inference/vision_language_multi_image.py
View file @
d9fc8cd9
...
...
@@ -687,6 +687,11 @@ def run_chat(model: str, question: str, image_urls: list[str],
seed
:
Optional
[
int
]):
req_data
=
model_example_map
[
model
](
question
,
image_urls
)
# Disable other modalities to save memory
default_limits
=
{
"image"
:
0
,
"video"
:
0
,
"audio"
:
0
}
req_data
.
engine_args
.
limit_mm_per_prompt
=
default_limits
|
dict
(
req_data
.
engine_args
.
limit_mm_per_prompt
or
{})
engine_args
=
asdict
(
req_data
.
engine_args
)
|
{
"seed"
:
seed
}
llm
=
LLM
(
**
engine_args
)
...
...
tests/entrypoints/openai/test_audio.py
View file @
d9fc8cd9
...
...
@@ -12,7 +12,9 @@ from ...utils import RemoteOpenAIServer
MODEL_NAME
=
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
TEST_AUDIO_URLS
=
[
AudioAsset
(
"winning_call"
).
url
,
AudioAsset
(
"mary_had_lamb"
).
url
,
]
MAXIMUM_AUDIOS
=
2
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
@@ -24,6 +26,8 @@ def server():
"5"
,
"--enforce-eager"
,
"--trust-remote-code"
,
"--limit-mm-per-prompt"
,
f
"audio=
{
MAXIMUM_AUDIOS
}
"
,
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
...
...
@@ -46,7 +50,7 @@ def base64_encoded_audio() -> dict[str, str]:
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"audio_url"
,
TEST_AUDIO_URLS
)
@
pytest
.
mark
.
parametrize
(
"audio_url"
,
[
TEST_AUDIO_URLS
[
0
]]
)
async
def
test_single_chat_session_audio
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
audio_url
:
str
):
messages
=
[{
...
...
@@ -100,7 +104,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"audio_url"
,
TEST_AUDIO_URLS
)
@
pytest
.
mark
.
parametrize
(
"audio_url"
,
[
TEST_AUDIO_URLS
[
0
]]
)
async
def
test_single_chat_session_audio_base64encoded
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
audio_url
:
str
,
base64_encoded_audio
:
dict
[
str
,
str
]):
...
...
@@ -158,7 +162,7 @@ async def test_single_chat_session_audio_base64encoded(
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"audio_url"
,
TEST_AUDIO_URLS
)
@
pytest
.
mark
.
parametrize
(
"audio_url"
,
[
TEST_AUDIO_URLS
[
0
]]
)
async
def
test_single_chat_session_input_audio
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
audio_url
:
str
,
base64_encoded_audio
:
dict
[
str
,
str
]):
...
...
@@ -330,28 +334,21 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"audio_url"
,
TEST_AUDIO_URLS
)
@
pytest
.
mark
.
parametrize
(
"audio_urls"
,
[
TEST_AUDIO_URLS
,
TEST_AUDIO_URLS
+
[
TEST_AUDIO_URLS
[
0
]]])
async
def
test_multi_audio_input
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
audio_url
:
str
,
base64_encoded_audio
:
dict
[
str
,
str
]):
audio_urls
:
list
[
str
]):
messages
=
[{
"role"
:
"user"
,
"content"
:
[
{
*
(
{
"type"
:
"audio_url"
,
"audio_url"
:
{
"url"
:
audio_url
}
},
{
"type"
:
"input_audio"
,
"input_audio"
:
{
"data"
:
base64_encoded_audio
[
audio_url
],
"format"
:
"wav"
}
},
}
for
audio_url
in
audio_urls
),
{
"type"
:
"text"
,
"text"
:
"What's happening in this audio?"
...
...
@@ -359,20 +356,30 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
],
}]
with
pytest
.
raises
(
openai
.
BadRequestError
):
# test multi-audio input
await
client
.
chat
.
completions
.
create
(
if
len
(
audio_urls
)
>
MAXIMUM_AUDIOS
:
with
pytest
.
raises
(
openai
.
BadRequestError
):
# test multi-audio input
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
max_completion_tokens
=
10
,
temperature
=
0.0
,
)
# the server should still work afterwards
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
[
0
,
0
,
0
,
0
,
0
],
max_tokens
=
5
,
temperature
=
0.0
,
)
completion
=
completion
.
choices
[
0
].
text
assert
completion
is
not
None
and
len
(
completion
)
>=
0
else
:
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
max_completion_tokens
=
10
,
temperature
=
0.0
,
)
# the server should still work afterwards
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
[
0
,
0
,
0
,
0
,
0
],
max_tokens
=
5
,
temperature
=
0.0
,
)
completion
=
completion
.
choices
[
0
].
text
assert
completion
is
not
None
and
len
(
completion
)
>=
0
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
tests/models/decoder_only/vision_language/vlm_utils/core.py
View file @
d9fc8cd9
...
...
@@ -51,6 +51,10 @@ def run_test(
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
# Disable other modalities to save memory
default_limits
=
{
"image"
:
0
,
"video"
:
0
,
"audio"
:
0
}
limit_mm_per_prompt
=
default_limits
|
limit_mm_per_prompt
vllm_outputs_per_mm
=
[]
hf_outputs_per_mm
=
[]
...
...
tests/models/test_oot_registration.py
View file @
d9fc8cd9
...
...
@@ -90,6 +90,7 @@ def test_oot_registration_multimodal(
max_model_len
=
4096
,
enforce_eager
=
True
,
limit_mm_per_prompt
=
{
"image"
:
1
})
first_token
=
llm
.
get_tokenizer
().
decode
(
0
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
...
...
tests/multimodal/test_processing.py
View file @
d9fc8cd9
...
...
@@ -972,10 +972,13 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
if
is_valid
:
exc_ctx
=
nullcontext
()
else
:
exc_ctx
=
pytest
.
raises
(
ValueError
,
match
=
"
this
model only supports"
)
exc_ctx
=
pytest
.
raises
(
ValueError
,
match
=
"
The
model only supports"
)
with
exc_ctx
:
profiler
.
get_decoder_dummy_data
(
model_config
.
max_model_len
)
profiler
.
get_decoder_dummy_data
(
model_config
.
max_model_len
,
mm_counts
=
limit_mm_per_prompt
,
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"llava-hf/llava-v1.6-mistral-7b-hf"
])
...
...
vllm/config.py
View file @
d9fc8cd9
...
...
@@ -2667,14 +2667,20 @@ class MultiModalConfig:
usedforsecurity
=
False
).
hexdigest
()
return
hash_str
def
get_default_limit_per_prompt
(
self
)
->
int
:
"""
Return the default number of input items allowed per prompt
for any modality if not specified by the user.
"""
return
999
if
envs
.
VLLM_USE_V1
else
1
def
get_limit_per_prompt
(
self
,
modality
:
str
)
->
int
:
"""
Get the maximum number of input items allowed per prompt
for the given modality.
If not set by the user, this defaults to `1`.
"""
return
self
.
limit_per_prompt
.
get
(
modality
,
1
)
default
=
self
.
get_default_limit_per_prompt
()
return
self
.
limit_per_prompt
.
get
(
modality
,
default
)
# TODO: Add configs to init vision tower or not.
...
...
vllm/engine/arg_utils.py
View file @
d9fc8cd9
...
...
@@ -671,13 +671,13 @@ class EngineArgs:
type
=
nullable_kvs
,
default
=
EngineArgs
.
limit_mm_per_prompt
,
# The default value is given in
# MultiModalConfig.get_limit_per_prompt
# MultiModalConfig.get_
default_
limit_per_prompt
help
=
(
'For each multimodal plugin, limit how many '
'input instances to allow for each prompt. '
'Expects a comma-separated list of items, '
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to
1 for
'
'each modality.'
))
'images and 2 videos per prompt. Defaults to '
'
1 (V0) or 999 (V1) for
each modality.'
))
parser
.
add_argument
(
'--mm-processor-kwargs'
,
default
=
None
,
...
...
vllm/entrypoints/chat_utils.py
View file @
d9fc8cd9
...
...
@@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalDataDict
from
vllm.multimodal.utils
import
MediaConnector
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
...
...
@@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self
.
_model_config
=
model_config
self
.
_tokenizer
=
tokenizer
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
if
model_config
.
multimodal_config
else
{})
self
.
_items_by_modality
=
defaultdict
[
str
,
list
[
_T
]](
list
)
...
...
@@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def
allowed_local_media_path
(
self
):
return
self
.
_model_config
.
allowed_local_media_path
@
property
def
mm_registry
(
self
):
return
MULTIMODAL_REGISTRY
@
staticmethod
@
cache
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
...
...
@@ -540,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
mm_registry
=
self
.
mm_registry
model_config
=
self
.
model_config
input_modality
=
modality
.
replace
(
"_embeds"
,
""
)
if
mm_registry
.
has_processor
(
model_config
):
mm_processor
=
mm_registry
.
create_processor
(
model_config
)
allowed_counts
=
mm_processor
.
info
.
get_allowed_mm_limits
()
allowed_count
=
allowed_counts
.
get
(
input_modality
,
0
)
else
:
mm_config
=
model_config
.
multimodal_config
if
mm_config
is
None
:
msg
=
"This model does not support multi-modal inputs"
raise
ValueError
(
msg
)
allowed_count
=
mm_config
.
get_limit_per_prompt
(
input_modality
)
current_count
=
len
(
self
.
_items_by_modality
[
modality
])
+
1
if
current_count
>
allowed_count
:
raise
ValueError
(
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
"one request."
)
"one request. You can set `--limit-mm-per-prompt` to "
"increase this limit if the model supports it."
)
self
.
_items_by_modality
[
modality
].
append
(
item
)
...
...
vllm/model_executor/models/minicpmo.py
View file @
d9fc8cd9
...
...
@@ -126,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
def
_parse_audio_data
(
self
,
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
AudioItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
)
->
Optional
[
ModalityDataItems
[
Any
,
Any
]
]
:
if
isinstance
(
data
,
dict
):
return
MiniCPMOAudioEmbeddingItems
(
data
,
...
...
vllm/model_executor/models/minicpmv.py
View file @
d9fc8cd9
...
...
@@ -290,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def
_parse_image_data
(
self
,
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
ImageItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
)
->
Optional
[
ModalityDataItems
[
Any
,
Any
]
]
:
if
isinstance
(
data
,
dict
):
return
MiniCPMVImageEmbeddingItems
(
data
,
...
...
@@ -302,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def
_parse_video_data
(
self
,
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
VideoItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
)
->
Optional
[
ModalityDataItems
[
Any
,
Any
]
]
:
if
isinstance
(
data
,
dict
):
return
MiniCPMVVideoEmbeddingItems
(
data
,
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
d9fc8cd9
...
...
@@ -720,7 +720,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
def
_parse_image_data
(
self
,
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
ImageItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
)
->
Optional
[
ModalityDataItems
[
Any
,
Any
]
]
:
if
isinstance
(
data
,
dict
):
return
DictEmbeddingItems
(
data
,
...
...
@@ -734,7 +734,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
def
_parse_video_data
(
self
,
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
VideoItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
)
->
Optional
[
ModalityDataItems
[
Any
,
Any
]
]
:
if
isinstance
(
data
,
dict
):
return
DictEmbeddingItems
(
data
,
...
...
vllm/multimodal/processing.py
View file @
d9fc8cd9
...
...
@@ -1034,6 +1034,20 @@ class BaseProcessingInfo:
"""
raise
NotImplementedError
def
get_allowed_mm_limits
(
self
)
->
Mapping
[
str
,
int
]:
"""Return the maximum allowed number of items for each modality."""
supported_mm_limits
=
self
.
get_supported_mm_limits
()
mm_config
=
self
.
ctx
.
get_mm_config
()
allowed_limits
=
dict
[
str
,
int
]()
for
modality
,
supported_limit
in
supported_mm_limits
.
items
():
user_limit
=
mm_config
.
get_limit_per_prompt
(
modality
)
allowed_limits
[
modality
]
=
(
user_limit
if
supported_limit
is
None
else
min
(
user_limit
,
supported_limit
))
return
allowed_limits
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
...
...
@@ -1087,14 +1101,24 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
before passing them to :meth:`_get_hf_mm_data`.
"""
mm_items
=
self
.
data_parser
.
parse_mm_data
(
mm_data
)
mm_config
=
self
.
info
.
ctx
.
get_mm_config
()
supported_mm_limits
=
self
.
info
.
get_supported_mm_limits
()
allowed_mm_limits
=
self
.
info
.
get_allowed_mm_limits
()
for
modality
,
items
in
mm_items
.
items
():
limit
=
mm_config
.
get_limit_per_prompt
(
modality
)
if
len
(
items
)
>
limit
:
supported_limit
=
supported_mm_limits
.
get
(
modality
,
0
)
allowed_limit
=
allowed_mm_limits
.
get
(
modality
,
0
)
num_items
=
len
(
items
)
if
supported_limit
is
not
None
and
num_items
>
supported_limit
:
raise
ValueError
(
f
"The model only supports at most
{
supported_limit
}
"
f
"
{
modality
}
items, but you passed
{
num_items
}
"
f
"
{
modality
}
items in the same prompt."
)
if
num_items
>
allowed_limit
:
raise
ValueError
(
f
"You set
{
modality
}
=
{
limit
}
(or defaulted to 1) in
"
f
"
`
--limit-mm-per-prompt`, but passed
{
len
(
items
)
}
"
f
"You set
or defaulted to
{
modality
}
=
{
allowed_limit
}
"
f
"
in
--limit-mm-per-prompt`, but passed
{
num_
items
}
"
f
"
{
modality
}
items in the same prompt."
)
return
mm_items
...
...
vllm/multimodal/profiling.py
View file @
d9fc8cd9
...
...
@@ -162,23 +162,7 @@ class MultiModalProfiler(Generic[_I]):
return
self
.
processor
.
dummy_inputs
def
get_mm_limits
(
self
)
->
Mapping
[
str
,
int
]:
mm_config
=
self
.
processing_info
.
ctx
.
get_mm_config
()
supported_mm_limits
=
self
.
processing_info
.
get_supported_mm_limits
()
mm_limits
=
{
modality
:
mm_config
.
get_limit_per_prompt
(
modality
)
for
modality
in
supported_mm_limits
}
for
modality
,
supported_limit
in
supported_mm_limits
.
items
():
limit
=
mm_limits
[
modality
]
if
supported_limit
is
not
None
and
supported_limit
<
limit
:
raise
ValueError
(
f
"You set
{
modality
}
=
{
limit
}
(or defaulted to 1) in "
f
"`--limit-mm-per-prompt`, but this model only supports "
f
"at most
{
supported_limit
}
{
modality
}
items."
)
return
mm_limits
return
self
.
processing_info
.
get_allowed_mm_limits
()
def
_get_dummy_mm_inputs
(
self
,
...
...
vllm/multimodal/registry.py
View file @
d9fc8cd9
...
...
@@ -265,8 +265,10 @@ class MultiModalRegistry:
return
profiler
.
get_mm_max_tokens
(
seq_len
,
{
modality
:
1
for
modality
in
mm_limits
},
{
modality
:
1
for
modality
,
limit
in
mm_limits
.
items
()
if
limit
>
0
},
)
return
{
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment