Unverified Commit e090b7b4 authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

Enable conversion of multimodal models to pooling tasks (#24451)


Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
parent 6a50eaa0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.platforms import current_platform
def test_idefics_multimodal(
vllm_runner,
monkeypatch,
) -> None:
if current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
with vllm_runner(model_name="HuggingFaceM4/Idefics3-8B-Llama3",
runner="pooling",
task="classify",
convert="classify",
load_format="dummy",
max_model_len=512,
enforce_eager=True,
tensor_parallel_size=1,
disable_log_stats=True,
dtype="bfloat16") as vllm_model:
llm = vllm_model.get_llm()
outputs = llm.classify(prompts)
for output in outputs:
assert len(output.outputs.probs) == 2
def update_config(config):
config.text_config.update({
"architectures": ["Gemma3ForSequenceClassification"],
"classifier_from_token": ["A", "B", "C", "D", "E"],
"method":
"no_post_processing",
"id2label": {
"A": "Chair",
"B": "Couch",
"C": "Table",
"D": "Bed",
"E": "Cupboard"
},
})
return config
def test_gemma_multimodal(
vllm_runner,
monkeypatch,
) -> None:
if current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
messages = [{
"role":
"system",
"content":
"""
You are a helpful assistant. You will be given a product description
which may also include an image. Classify the following product into
one of the categories:
A = chair
B = couch
C = table
D = bed
E = cupboard
You'll answer with exactly one letter (A, B, C, D, or E)."""
}, {
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url":
"https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
}
}, {
"type": "text",
"text": "A fine 19th century piece of furniture."
}]
}]
with vllm_runner(model_name="google/gemma-3-4b-it",
runner="pooling",
task="classify",
convert="classify",
load_format="auto",
hf_overrides=update_config,
override_pooler_config={"pooling_type": "LAST"},
max_model_len=512,
enforce_eager=True,
tensor_parallel_size=1,
disable_log_stats=True,
dtype="bfloat16") as vllm_model:
llm = vllm_model.get_llm()
prompts = llm.preprocess_chat(messages)
result = llm.classify(prompts)
assert result[0].outputs.probs[0] > 0.95
assert all(c < 0.05 for c in result[0].outputs.probs[1:])
\ No newline at end of file
...@@ -703,13 +703,10 @@ class LLM: ...@@ -703,13 +703,10 @@ class LLM:
return outputs return outputs
def chat( def preprocess_chat(
self, self,
messages: Union[list[ChatCompletionMessageParam], messages: Union[list[ChatCompletionMessageParam],
list[list[ChatCompletionMessageParam]]], list[list[ChatCompletionMessageParam]]],
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto", chat_template_content_format: ChatTemplateContentFormatOption = "auto",
...@@ -718,56 +715,16 @@ class LLM: ...@@ -718,56 +715,16 @@ class LLM:
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None, chat_template_kwargs: Optional[dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> list[RequestOutput]: ) -> list[TokensPrompt]:
""" """
Generate responses for a chat conversation. Generate prompt for a chat conversation. The pre-processed
prompt can then be used as input for the other LLM methods.
The chat conversation is converted into a text prompt using the
tokenizer and calls the [generate][vllm.LLM.generate] method to generate
the responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content.
- "string" will render the content as a string.
Example: `"Who are you?"`
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: `[{"type": "text", "text": "Who are you?"}]`
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be
`True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Refer to `chat` for a complete description of the arguments.
Returns: Returns:
A list of `RequestOutput` objects containing the generated A list of `TokensPrompts` objects containing the tokenized
responses in the same order as the input messages. prompt after chat template interpolation, and the
pre-processed multi-modal inputs.
""" """
list_of_messages: list[list[ChatCompletionMessageParam]] list_of_messages: list[list[ChatCompletionMessageParam]]
...@@ -800,7 +757,7 @@ class LLM: ...@@ -800,7 +757,7 @@ class LLM:
) )
_chat_template_kwargs.update(chat_template_kwargs or {}) _chat_template_kwargs.update(chat_template_kwargs or {})
prompts: list[Union[TokensPrompt, TextPrompt]] = [] prompts: list[TokensPrompt] = []
for msgs in list_of_messages: for msgs in list_of_messages:
# NOTE: _parse_chat_message_content_parts() currently doesn't # NOTE: _parse_chat_message_content_parts() currently doesn't
...@@ -844,6 +801,87 @@ class LLM: ...@@ -844,6 +801,87 @@ class LLM:
prompts.append(prompt) prompts.append(prompt)
return prompts
def chat(
self,
messages: Union[list[ChatCompletionMessageParam],
list[list[ChatCompletionMessageParam]]],
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[list[dict[str, Any]]] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> list[RequestOutput]:
"""
Generate responses for a chat conversation.
The chat conversation is converted into a text prompt using the
tokenizer and calls the [generate][vllm.LLM.generate] method to generate
the responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content.
- "string" will render the content as a string.
Example: `"Who are you?"`
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: `[{"type": "text", "text": "Who are you?"}]`
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be
`True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns:
A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages.
"""
prompts = self.preprocess_chat(
messages=messages,
lora_request=lora_request,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
chat_template_kwargs=chat_template_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
return self.generate( return self.generate(
prompts, prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
......
...@@ -19,10 +19,11 @@ from vllm.logger import init_logger ...@@ -19,10 +19,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.linear import QKVCrossParallelLinear
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.models.adapters import (as_embedding_model, from vllm.model_executor.models.adapters import (
as_reward_model, as_embedding_model, as_reward_model, as_seq_cls_model,
as_seq_cls_model) try_create_mm_pooling_model_cls)
from vllm.model_executor.models.interfaces import SupportsQuant from vllm.model_executor.models.interfaces import (SupportsQuant,
supports_multimodal)
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -183,6 +184,15 @@ def get_model_architecture( ...@@ -183,6 +184,15 @@ def get_model_architecture(
"performance may not be optimal.", arch) "performance may not be optimal.", arch)
convert_type = model_config.convert_type convert_type = model_config.convert_type
if convert_type != "none" and supports_multimodal(model_cls):
logger.debug_once("Detected conversion of Multi Modal model.")
converted = try_create_mm_pooling_model_cls(model_cls)
if converted is not None:
logger.debug_once("Creating wrapper class to forward pooler.")
return converted, arch
else:
logger.debug_once("Attempting direct conversion.")
if convert_type == "none": if convert_type == "none":
pass pass
elif convert_type == "embed": elif convert_type == "embed":
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import inspect
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.config import VerifyAndUpdateConfig from vllm.model_executor.models.config import VerifyAndUpdateConfig
...@@ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: ...@@ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
return model_name + pooling_suffix return model_name + pooling_suffix
def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T:
class CallVisitor(ast.NodeVisitor):
def __init__(self):
self.calls = []
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
self.calls.append(node.func.id)
self.generic_visit(node)
visitor = CallVisitor()
visitor.visit(ast.parse(inspect.getsource(orig_cls)))
if "init_vllm_registered_model" not in visitor.calls:
return None
class ModelForPooling(orig_cls, VllmModelForPooling):
is_pooling_model = True
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.pooler = self.get_language_model().pooler
return ModelForPooling # type: ignore
def _create_pooling_model_cls(orig_cls: _T) -> _T: def _create_pooling_model_cls(orig_cls: _T) -> _T:
# Lazy import # Lazy import
from .utils import AutoWeightsLoader, WeightsMapper from .utils import AutoWeightsLoader, WeightsMapper
...@@ -399,6 +437,7 @@ def load_weights_using_from_2_way_softmax( ...@@ -399,6 +437,7 @@ def load_weights_using_from_2_way_softmax(
from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.models.utils import AutoWeightsLoader
model_config = model.vllm_config.model_config model_config = model.vllm_config.model_config
tokens = getattr(model.config, "classifier_from_token", []) tokens = getattr(model.config, "classifier_from_token", [])
tokens = cast(list[int], tokens) tokens = cast(list[int], tokens)
assert len(tokens) == 2 assert len(tokens) == 2
...@@ -406,9 +445,10 @@ def load_weights_using_from_2_way_softmax( ...@@ -406,9 +445,10 @@ def load_weights_using_from_2_way_softmax(
if model.config.tie_word_embeddings: if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens model.lm_head = model.model.embed_tokens
else: else:
quant_config = model.vllm_config.quant_config
model.lm_head = ParallelLMHead(model.config.vocab_size, model.lm_head = ParallelLMHead(model.config.vocab_size,
model.config.hidden_size, model.config.hidden_size,
quant_config=model.quant_config) quant_config=quant_config)
loader = AutoWeightsLoader(model) loader = AutoWeightsLoader(model)
loaded_weights = loader.load_weights(weights) loaded_weights = loader.load_weights(weights)
...@@ -452,9 +492,10 @@ def load_weights_no_post_processing(model, ...@@ -452,9 +492,10 @@ def load_weights_no_post_processing(model,
if model.config.tie_word_embeddings: if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens model.lm_head = model.model.embed_tokens
else: else:
quant_config = model.vllm_config.quant_config
model.lm_head = ParallelLMHead(model.config.vocab_size, model.lm_head = ParallelLMHead(model.config.vocab_size,
model.config.hidden_size, model.config.hidden_size,
quant_config=model.quant_config) quant_config=quant_config)
loader = AutoWeightsLoader(model) loader = AutoWeightsLoader(model)
loaded_weights = loader.load_weights(weights) loaded_weights = loader.load_weights(weights)
......
...@@ -512,7 +512,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -512,7 +512,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
architectures=["Gemma3ForCausalLM"], architectures=["Gemma3ForCausalLM"],
) )
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale
if hasattr(self.language_model, "logits_processor"):
# The logits processor can be unset if we're using
# automatic conversion to pooling model.
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment