Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e090b7b4
Unverified
Commit
e090b7b4
authored
Sep 12, 2025
by
Maximilien de Bayser
Committed by
GitHub
Sep 12, 2025
Browse files
Enable conversion of multimodal models to pooling tasks (#24451)
Signed-off-by:
Max de Bayser
<
mbayser@br.ibm.com
>
parent
6a50eaa0
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
266 additions
and
59 deletions
+266
-59
tests/models/language/pooling/test_mm_classifier_conversion.py
.../models/language/pooling/test_mm_classifier_conversion.py
+114
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+90
-52
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+14
-4
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+43
-2
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+5
-1
No files found.
tests/models/language/pooling/test_mm_classifier_conversion.py
0 → 100644
View file @
e090b7b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.platforms
import
current_platform
def test_idefics_multimodal(
    vllm_runner,
    monkeypatch,
) -> None:
    """Run text-only prompts through Idefics3 converted to a classifier.

    The model is loaded with dummy weights, so only the shape of the
    result is checked: every output must carry exactly two class
    probabilities.
    """
    if current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    runner_kwargs = dict(
        model_name="HuggingFaceM4/Idefics3-8B-Llama3",
        runner="pooling",
        task="classify",
        convert="classify",
        load_format="dummy",
        max_model_len=512,
        enforce_eager=True,
        tensor_parallel_size=1,
        disable_log_stats=True,
        dtype="bfloat16",
    )
    text_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    with vllm_runner(**runner_kwargs) as vllm_model:
        classified = vllm_model.get_llm().classify(text_prompts)
        for item in classified:
            assert len(item.outputs.probs) == 2
def update_config(config):
    """Apply sequence-classification overrides to ``config.text_config``.

    The text config is updated in place with a
    ``Gemma3ForSequenceClassification`` architecture, the five classifier
    tokens ``A``-``E``, the ``no_post_processing`` method and a mapping
    from each token to its furniture label. The same ``config`` object is
    returned so the function can be used as an ``hf_overrides`` callable.
    """
    class_tokens = ["A", "B", "C", "D", "E"]
    class_names = ["Chair", "Couch", "Table", "Bed", "Cupboard"]

    overrides = {
        "architectures": ["Gemma3ForSequenceClassification"],
        "classifier_from_token": class_tokens,
        "method": "no_post_processing",
        "id2label": dict(zip(class_tokens, class_names)),
    }
    config.text_config.update(overrides)
    return config
def test_gemma_multimodal(
    vllm_runner,
    monkeypatch,
) -> None:
    """Classify an image+text chat prompt with Gemma3 converted to a classifier.

    The conversation is turned into prompts with ``preprocess_chat`` and
    passed to ``classify``. The first probability must exceed 0.95 and
    every remaining one must stay below 0.05 (presumably index 0 is the
    "A"/chair class from ``update_config`` -- confirm against the pooler).
    """
    if current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"

    system_message = {
        "role": "system",
        "content": """
You are a helpful assistant. You will be given a product description
which may also include an image. Classify the following product into
one of the categories:
A = chair
B = couch
C = table
D = bed
E = cupboard
You'll answer with exactly one letter (A, B, C, D, or E)."""
    }
    user_message = {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                },
            },
            {
                "type": "text",
                "text": "A fine 19th century piece of furniture."
            },
        ],
    }
    messages = [system_message, user_message]

    with vllm_runner(
            model_name="google/gemma-3-4b-it",
            runner="pooling",
            task="classify",
            convert="classify",
            load_format="auto",
            hf_overrides=update_config,
            override_pooler_config={"pooling_type": "LAST"},
            max_model_len=512,
            enforce_eager=True,
            tensor_parallel_size=1,
            disable_log_stats=True,
            dtype="bfloat16",
    ) as vllm_model:
        llm = vllm_model.get_llm()
        result = llm.classify(llm.preprocess_chat(messages))

        assert result[0].outputs.probs[0] > 0.95
        assert all(c < 0.05 for c in result[0].outputs.probs[1:])
\ No newline at end of file
vllm/entrypoints/llm.py
View file @
e090b7b4
...
@@ -703,13 +703,10 @@ class LLM:
...
@@ -703,13 +703,10 @@ class LLM:
return
outputs
return
outputs
def
chat
(
def
preprocess_
chat
(
self
,
self
,
messages
:
Union
[
list
[
ChatCompletionMessageParam
],
messages
:
Union
[
list
[
ChatCompletionMessageParam
],
list
[
list
[
ChatCompletionMessageParam
]]],
list
[
list
[
ChatCompletionMessageParam
]]],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
list
[
SamplingParams
]]]
=
None
,
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
...
@@ -718,56 +715,16 @@ class LLM:
...
@@ -718,56 +715,16 @@ class LLM:
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]]
=
None
,
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]]
=
None
,
chat_template_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
chat_template_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
RequestOutpu
t
]:
)
->
list
[
TokensPromp
t
]:
"""
"""
Generate responses for a chat conversation.
Generate prompt for a chat conversation. The pre-processed
prompt can then be used as input for the other LLM methods.
The chat conversation is converted into a text prompt using the
tokenizer and calls the [generate][vllm.LLM.generate] method to generate
the responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content.
- "string" will render the content as a string.
Example: `"Who are you?"`
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: `[{"type": "text", "text": "Who are you?"}]`
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be
`True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Refer to `chat` for a complete description of the arguments.
Returns:
Returns:
A list of `RequestOutput` objects containing the generated
A list of `TokensPrompts` objects containing the tokenized
responses in the same order as the input messages.
prompt after chat template interpolation, and the
pre-processed multi-modal inputs.
"""
"""
list_of_messages
:
list
[
list
[
ChatCompletionMessageParam
]]
list_of_messages
:
list
[
list
[
ChatCompletionMessageParam
]]
...
@@ -800,7 +757,7 @@ class LLM:
...
@@ -800,7 +757,7 @@ class LLM:
)
)
_chat_template_kwargs
.
update
(
chat_template_kwargs
or
{})
_chat_template_kwargs
.
update
(
chat_template_kwargs
or
{})
prompts
:
list
[
Union
[
TokensPrompt
,
TextPrompt
]
]
=
[]
prompts
:
list
[
TokensPrompt
]
=
[]
for
msgs
in
list_of_messages
:
for
msgs
in
list_of_messages
:
# NOTE: _parse_chat_message_content_parts() currently doesn't
# NOTE: _parse_chat_message_content_parts() currently doesn't
...
@@ -844,6 +801,87 @@ class LLM:
...
@@ -844,6 +801,87 @@ class LLM:
prompts
.
append
(
prompt
)
prompts
.
append
(
prompt
)
return
prompts
def
chat
(
self
,
messages
:
Union
[
list
[
ChatCompletionMessageParam
],
list
[
list
[
ChatCompletionMessageParam
]]],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
list
[
SamplingParams
]]]
=
None
,
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
add_generation_prompt
:
bool
=
True
,
continue_final_message
:
bool
=
False
,
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]]
=
None
,
chat_template_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
RequestOutput
]:
"""
Generate responses for a chat conversation.
The chat conversation is converted into a text prompt using the
tokenizer and calls the [generate][vllm.LLM.generate] method to generate
the responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content.
- "string" will render the content as a string.
Example: `"Who are you?"`
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: `[{"type": "text", "text": "Who are you?"}]`
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be
`True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns:
A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages.
"""
prompts
=
self
.
preprocess_chat
(
messages
=
messages
,
lora_request
=
lora_request
,
chat_template
=
chat_template
,
chat_template_content_format
=
chat_template_content_format
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
chat_template_kwargs
=
chat_template_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
return
self
.
generate
(
return
self
.
generate
(
prompts
,
prompts
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
...
...
vllm/model_executor/model_loader/utils.py
View file @
e090b7b4
...
@@ -19,10 +19,11 @@ from vllm.logger import init_logger
...
@@ -19,10 +19,11 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
QKVCrossParallelLinear
from
vllm.model_executor.layers.linear
import
QKVCrossParallelLinear
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.models.adapters
import
(
as_embedding_model
,
from
vllm.model_executor.models.adapters
import
(
as_reward_model
,
as_embedding_model
,
as_reward_model
,
as_seq_cls_model
,
as_seq_cls_model
)
try_create_mm_pooling_model_cls
)
from
vllm.model_executor.models.interfaces
import
SupportsQuant
from
vllm.model_executor.models.interfaces
import
(
SupportsQuant
,
supports_multimodal
)
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -183,6 +184,15 @@ def get_model_architecture(
...
@@ -183,6 +184,15 @@ def get_model_architecture(
"performance may not be optimal."
,
arch
)
"performance may not be optimal."
,
arch
)
convert_type
=
model_config
.
convert_type
convert_type
=
model_config
.
convert_type
if
convert_type
!=
"none"
and
supports_multimodal
(
model_cls
):
logger
.
debug_once
(
"Detected conversion of Multi Modal model."
)
converted
=
try_create_mm_pooling_model_cls
(
model_cls
)
if
converted
is
not
None
:
logger
.
debug_once
(
"Creating wrapper class to forward pooler."
)
return
converted
,
arch
else
:
logger
.
debug_once
(
"Attempting direct conversion."
)
if
convert_type
==
"none"
:
if
convert_type
==
"none"
:
pass
pass
elif
convert_type
==
"embed"
:
elif
convert_type
==
"embed"
:
...
...
vllm/model_executor/models/adapters.py
View file @
e090b7b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
import
inspect
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
cast
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.models.config
import
VerifyAndUpdateConfig
from
vllm.model_executor.models.config
import
VerifyAndUpdateConfig
...
@@ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
...
@@ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
return
model_name
+
pooling_suffix
return
model_name
+
pooling_suffix
def try_create_mm_pooling_model_cls(orig_cls: _T) -> Optional[_T]:
    """Wrap a multimodal model class so it exposes its language model's pooler.

    The source of ``orig_cls`` is scanned for a bare-name call to
    ``init_vllm_registered_model``. When such a call is found, a subclass
    is returned that also derives from ``VllmModelForPooling`` and, after
    normal construction, sets ``self.pooler`` to
    ``self.get_language_model().pooler``.

    Args:
        orig_cls: The model class to inspect and wrap.

    Returns:
        The wrapper class, or ``None`` when ``orig_cls`` does not call
        ``init_vllm_registered_model`` or when its source cannot be
        retrieved or parsed. Callers treat ``None`` as "no wrapper
        available".
    """
    # Local import: only this helper needs it.
    import textwrap

    class CallVisitor(ast.NodeVisitor):
        """Record the name of every function that is called by bare name."""

        def __init__(self):
            self.calls = []

        def visit_Call(self, node):
            if isinstance(node.func, ast.Name):
                self.calls.append(node.func.id)
            # Keep descending so nested calls are recorded as well.
            self.generic_visit(node)

    try:
        # Dedent first: the source of a class defined inside another scope
        # is indented and would otherwise fail to parse.
        source = textwrap.dedent(inspect.getsource(orig_cls))
        tree = ast.parse(source)
    except (OSError, TypeError, SyntaxError):
        # Source unavailable (builtin, dynamically created, bytecode-only)
        # or not parseable on its own: report "no wrapper" instead of
        # aborting model loading.
        return None

    visitor = CallVisitor()
    visitor.visit(tree)
    if "init_vllm_registered_model" not in visitor.calls:
        return None

    class ModelForPooling(orig_cls, VllmModelForPooling):

        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config,
                             prefix=prefix,
                             **kwargs)

            # Forward the pooler of the wrapped language model so the
            # multimodal wrapper itself exposes it.
            self.pooler = self.get_language_model().pooler

    return ModelForPooling  # type: ignore
def
_create_pooling_model_cls
(
orig_cls
:
_T
)
->
_T
:
def
_create_pooling_model_cls
(
orig_cls
:
_T
)
->
_T
:
# Lazy import
# Lazy import
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
...
@@ -399,6 +437,7 @@ def load_weights_using_from_2_way_softmax(
...
@@ -399,6 +437,7 @@ def load_weights_using_from_2_way_softmax(
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
model_config
=
model
.
vllm_config
.
model_config
model_config
=
model
.
vllm_config
.
model_config
tokens
=
getattr
(
model
.
config
,
"classifier_from_token"
,
[])
tokens
=
getattr
(
model
.
config
,
"classifier_from_token"
,
[])
tokens
=
cast
(
list
[
int
],
tokens
)
tokens
=
cast
(
list
[
int
],
tokens
)
assert
len
(
tokens
)
==
2
assert
len
(
tokens
)
==
2
...
@@ -406,9 +445,10 @@ def load_weights_using_from_2_way_softmax(
...
@@ -406,9 +445,10 @@ def load_weights_using_from_2_way_softmax(
if
model
.
config
.
tie_word_embeddings
:
if
model
.
config
.
tie_word_embeddings
:
model
.
lm_head
=
model
.
model
.
embed_tokens
model
.
lm_head
=
model
.
model
.
embed_tokens
else
:
else
:
quant_config
=
model
.
vllm_config
.
quant_config
model
.
lm_head
=
ParallelLMHead
(
model
.
config
.
vocab_size
,
model
.
lm_head
=
ParallelLMHead
(
model
.
config
.
vocab_size
,
model
.
config
.
hidden_size
,
model
.
config
.
hidden_size
,
quant_config
=
model
.
quant_config
)
quant_config
=
quant_config
)
loader
=
AutoWeightsLoader
(
model
)
loader
=
AutoWeightsLoader
(
model
)
loaded_weights
=
loader
.
load_weights
(
weights
)
loaded_weights
=
loader
.
load_weights
(
weights
)
...
@@ -452,9 +492,10 @@ def load_weights_no_post_processing(model,
...
@@ -452,9 +492,10 @@ def load_weights_no_post_processing(model,
if
model
.
config
.
tie_word_embeddings
:
if
model
.
config
.
tie_word_embeddings
:
model
.
lm_head
=
model
.
model
.
embed_tokens
model
.
lm_head
=
model
.
model
.
embed_tokens
else
:
else
:
quant_config
=
model
.
vllm_config
.
quant_config
model
.
lm_head
=
ParallelLMHead
(
model
.
config
.
vocab_size
,
model
.
lm_head
=
ParallelLMHead
(
model
.
config
.
vocab_size
,
model
.
config
.
hidden_size
,
model
.
config
.
hidden_size
,
quant_config
=
model
.
quant_config
)
quant_config
=
quant_config
)
loader
=
AutoWeightsLoader
(
model
)
loader
=
AutoWeightsLoader
(
model
)
loaded_weights
=
loader
.
load_weights
(
weights
)
loaded_weights
=
loader
.
load_weights
(
weights
)
...
...
vllm/model_executor/models/gemma3_mm.py
View file @
e090b7b4
...
@@ -512,6 +512,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -512,6 +512,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
architectures
=
[
"Gemma3ForCausalLM"
],
architectures
=
[
"Gemma3ForCausalLM"
],
)
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
if
hasattr
(
self
.
language_model
,
"logits_processor"
):
# The logits processor can be unset if we're using
# automatic conversion to pooling model.
self
.
language_model
.
logits_processor
.
scale
*=
logit_scale
self
.
language_model
.
logits_processor
.
scale
*=
logit_scale
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment