Unverified Commit 1b886aa1 authored by Austin Veselka's avatar Austin Veselka Committed by GitHub
Browse files

[Model] Adding Support for Qwen2VL as an Embedding Model. Using MrLight/dse-qwen2-2b-mrl-v1 (#9944)


Signed-off-by: default avatarFurtherAI <austin.veselka@lighton.ai>
Co-authored-by: default avatarFurtherAI <austin.veselka@lighton.ai>
parent 3945c823
...@@ -584,6 +584,12 @@ Multimodal Embedding ...@@ -584,6 +584,12 @@ Multimodal Embedding
- :code:`TIGER-Lab/VLM2Vec-Full` - :code:`TIGER-Lab/VLM2Vec-Full`
- 🚧 - 🚧
- ✅︎ - ✅︎
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL-based
- T + I
- :code:`MrLight/dse-qwen2-2b-mrl-v1`
-
- ✅︎
.. important:: .. important::
Some model architectures support both generation and embedding tasks. Some model architectures support both generation and embedding tasks.
......
...@@ -310,4 +310,21 @@ Since the request schema is not defined by OpenAI client, we post a request to t ...@@ -310,4 +310,21 @@ Since the request schema is not defined by OpenAI client, we post a request to t
response_json = response.json() response_json = response.json()
print("Embedding output:", response_json["data"][0]["embedding"]) print("Embedding output:", response_json["data"][0]["embedding"])
Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model.
.. code-block:: bash
vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embedding \
--trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja
.. important::
Like with VLM2Vec, we have to explicitly pass ``--task embedding``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings,
which is handled by the jinja template.
.. important::
Also important, ``MrLight/dse-qwen2-2b-mrl-v1`` requires a placeholder image of the minimum image size for text query embeddings. See the full code
example below for details.
A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_embedding_client_for_multimodal.py>`_. A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_embedding_client_for_multimodal.py>`_.
import argparse
import base64
import io
import requests import requests
from PIL import Image
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
response = requests.post(
"http://localhost:8000/v1/embeddings", def vlm2vec():
json={ response = requests.post(
"model": "http://localhost:8000/v1/embeddings",
"TIGER-Lab/VLM2Vec-Full", json={
"messages": [{ "model":
"TIGER-Lab/VLM2Vec-Full",
"messages": [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "Represent the given image."
},
],
}],
"encoding_format":
"float",
},
)
response.raise_for_status()
response_json = response.json()
print("Embedding output:", response_json["data"][0]["embedding"])
def dse_qwen2_vl(inp: dict):
# Embedding an Image
if inp["dtype"] == "image":
messages = [{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": inp["image_url"],
}
}, {
"type": "text",
"text": "What is shown in this image?"
}]
}]
# Embedding a Text Query
else:
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
# of the minimum input size
buffer = io.BytesIO()
image_placeholder = Image.new("RGB", (56, 56))
image_placeholder.save(buffer, "png")
buffer.seek(0)
image_placeholder = base64.b64encode(buffer.read()).decode('utf-8')
messages = [{
"role": "role":
"user", "user",
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": image_url "url": f"data:image/jpeg;base64,{image_placeholder}",
} }
}, },
{ {
"type": "text", "type": "text",
"text": "Represent the given image." "text": f"Query: {inp['content']}"
}, },
], ]
}], }]
"encoding_format":
"float", response = requests.post(
}, "http://localhost:8000/v1/embeddings",
) json={
response.raise_for_status() "model": "MrLight/dse-qwen2-2b-mrl-v1",
response_json = response.json() "messages": messages,
"encoding_format": "float",
print("Embedding output:", response_json["data"][0]["embedding"]) },
)
response.raise_for_status()
response_json = response.json()
print("Embedding output:", response_json["data"][0]["embedding"])
if __name__ == '__main__':
parser = argparse.ArgumentParser(
"Script to call a specified VLM through the API. Make sure to serve "
"the model with --task embedding before running this.")
parser.add_argument("model",
type=str,
choices=["vlm2vec", "dse_qwen2_vl"],
required=True,
help="Which model to call.")
args = parser.parse_args()
if args.model == "vlm2vec":
vlm2vec()
elif args.model == "dse_qwen2_vl":
dse_qwen2_vl({
"dtye": "image",
"image_url": image_url,
})
dse_qwen2_vl({
"dtype": "text",
"content": "What is the weather like today?",
})
{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{% raw %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endraw %}{% endif %}<|im_start|>{{ message['role'] }}{% raw %}
{% endraw %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>{% raw %}
{% endraw %}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>{% raw %}
{% endraw %}{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant{% raw %}
{% endraw %}{% endif %}<|endoftext|>
\ No newline at end of file
...@@ -243,6 +243,9 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict) ...@@ -243,6 +243,9 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
class HfRunner: class HfRunner:
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
if x is None or isinstance(x, (bool, )):
return x
if device is None: if device is None:
device = "cpu" if current_platform.is_cpu() else "cuda" device = "cpu" if current_platform.is_cpu() else "cuda"
......
from functools import partial
from typing import Callable, Dict, List, Type
import pytest
import torch
from PIL import Image
from transformers import BatchEncoding, Qwen2VLForConditionalGeneration
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close
HF_TEXT_PROMPTS = [
# T -> X
(
"Query: Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501,
Image.new("RGB", (56, 56))),
# T -> X
("Query: Retrieve an image of this caption: cherry blossom",
Image.new("RGB", (56, 56))),
]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"What is shown in this image?",
"cherry_blossom":
"What is shown in this image?"
})
MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"]
def get_messages(image: Image.Image, text: str, embed_text: bool):
# assert False, 'remember to use outer [] as required'
if embed_text:
messages = [{
"role":
"user",
"content": [
{
"type": "image",
"image": Image.new("RGB", (56, 56)),
"resized_height": 1,
"resized_width": 1
}, # need a dummy image here for an easier process.
{
"type": "text",
"text": text
},
]
}]
else:
messages = [{
"role":
"user",
"content": [{
"type": "image",
"image": image
}, {
"type": "text",
"text": text
}]
}]
return messages
def apply_chat_template_and_add_eos(
messages: List[Dict],
apply_chat_template_fn: Callable,
):
prompt = apply_chat_template_fn(
messages, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
return prompt
def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs):
return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs)
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
input_texts: List[str],
input_images: PromptImageInput,
embed_texts: List[bool],
model: str,
*,
dtype: str,
) -> None:
'''SET PYTHONPATH'''
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
task="embedding",
dtype=dtype,
enforce_eager=True,
max_model_len=8192) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()
texts = [
# this is necessary because vllm_model.encode will not apply any
# templating to the prompt, and therefore lacks an image_pad
# token unless one is inserted beforehand (the (28,28) image
# above is converted to an image pad token by the chat template).
apply_chat_template_and_add_eos(
get_messages(image, text, False),
apply_chat_template_fn=tokenizer.apply_chat_template,
) for text, image in zip(input_texts, input_images)
# vllm will replace the pad token with the actual image,
# which may be a placeholder image, later.
]
vllm_outputs = vllm_model.encode(texts, images=input_images)
hf_outputs = []
with hf_runner(model,
dtype=dtype,
auto_cls=Qwen2VLForConditionalGeneration) as hf_model:
hf_model.postprocess_inputs = partial(
postprocess_inputs,
hf_model,
cache_position=torch.arange(
0,
1, # 1 for batch size
requires_grad=False),
use_cache=False)
for text, image, embed_text in zip(input_texts, input_images,
embed_texts):
# dse requires non-standard input processing
# because it needs an image_pad token
messages = get_messages(image, text, embed_text)
prompt = apply_chat_template_and_add_eos(
messages, hf_model.processor.apply_chat_template)
inputs = hf_model.get_inputs(
prompts=[[prompt]],
images=[[image]],
)
with torch.no_grad():
outputs = hf_model.model(
**hf_model.wrap_device(inputs[0],
device=hf_model.model.device.type),
return_dict=True,
output_hidden_states=True,
)
pooled_output = torch.nn.functional.normalize(
outputs.hidden_states[-1][0, -1], p=2, dim=-1)
hf_outputs.append(pooled_output.tolist())
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_text(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [(text, image_placeholder)
for text, image_placeholder in HF_TEXT_PROMPTS]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
embed_texts = [True] * len(input_texts)
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images, # type: ignore
embed_texts,
model,
dtype=dtype,
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
embed_texts = [False] * len(input_texts)
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images,
embed_texts,
model,
dtype=dtype,
)
...@@ -51,6 +51,7 @@ from vllm.model_executor.layers.activation import QuickGELU ...@@ -51,6 +51,7 @@ from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import (GPTQConfig, from vllm.model_executor.layers.quantization import (GPTQConfig,
GPTQMarlinConfig, GPTQMarlinConfig,
QuantizationConfig) QuantizationConfig)
...@@ -58,12 +59,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -58,12 +59,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs) MultiModalKwargs)
from vllm.multimodal.base import MultiModalData from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
...@@ -1067,6 +1069,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1067,6 +1069,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching" "Qwen2-VL currently does not support prefix caching"
...@@ -1098,6 +1101,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1098,6 +1101,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
...@@ -1318,6 +1326,13 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1318,6 +1326,13 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
......
...@@ -109,6 +109,7 @@ _EMBEDDING_MODELS = { ...@@ -109,6 +109,7 @@ _EMBEDDING_MODELS = {
# [Multimodal] # [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
} }
_MULTIMODAL_MODELS = { _MULTIMODAL_MODELS = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment