Unverified Commit 60446cd6 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Model] Improve multimodal pooling examples (#32085)


Signed-off-by: default avatarwang.yuqi <noooop@126.com>
Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
parent 9101dc75
...@@ -362,7 +362,7 @@ and passing a list of `messages` in the request. Refer to the examples below for ...@@ -362,7 +362,7 @@ and passing a list of `messages` in the request. Refer to the examples below for
`MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code `MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code
example below for details. example below for details.
Full example: [examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py](../../examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py) Full example: [examples/pooling/embed/vision_embedding_online.py](../../examples/pooling/embed/vision_embedding_online.py)
#### Extra parameters #### Extra parameters
...@@ -667,7 +667,7 @@ Usually, the score for a sentence pair refers to the similarity between two sent ...@@ -667,7 +667,7 @@ Usually, the score for a sentence pair refers to the similarity between two sent
You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
Code example: [examples/pooling/score/openai_cross_encoder_score.py](../../examples/pooling/score/openai_cross_encoder_score.py) Code example: [examples/pooling/score/score_api_online.py](../../examples/pooling/score/score_api_online.py)
#### Score Template #### Score Template
...@@ -863,7 +863,10 @@ You can pass multi-modal inputs to scoring models by passing `content` including ...@@ -863,7 +863,10 @@ You can pass multi-modal inputs to scoring models by passing `content` including
print("Scoring output:", response_json["data"][0]["score"]) print("Scoring output:", response_json["data"][0]["score"])
print("Scoring output:", response_json["data"][1]["score"]) print("Scoring output:", response_json["data"][1]["score"])
``` ```
Full example: [examples/pooling/score/openai_cross_encoder_score_for_multimodal.py](../../examples/pooling/score/openai_cross_encoder_score_for_multimodal.py) Full example:
- [examples/pooling/score/vision_score_api_online.py](../../examples/pooling/score/vision_score_api_online.py)
- [examples/pooling/score/vision_rerank_api_online.py](../../examples/pooling/score/vision_rerank_api_online.py)
#### Extra parameters #### Extra parameters
...@@ -893,7 +896,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin ...@@ -893,7 +896,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin
[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with [Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with
popular open-source tools. popular open-source tools.
Code example: [examples/pooling/score/openai_reranker.py](../../examples/pooling/score/openai_reranker.py) Code example: [examples/pooling/score/rerank_api_online.py](../../examples/pooling/score/rerank_api_online.py)
#### Example Request #### Example Request
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for multimodal embedding.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import argparse
from dataclasses import asdict
from vllm import LLM, EngineArgs
from vllm.multimodal.utils import fetch_image
# Sample inputs used by the example below: an image URL and a text string.
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/cat_snow.jpg"
text = "A cat standing in the snow."
# NOTE: fetch_image() is called here at module import time; presumably it
# downloads the image behind image_url — confirm against vllm.multimodal.utils.
multi_modal_data = {"image": fetch_image(image_url)}
def print_embeddings(embeds, max_shown: int = 4):
    """Print a short preview of an embedding together with its length.

    Args:
        embeds: The embedding to display; must support ``len()`` and slicing.
        max_shown: How many leading values to show before the rest is elided
            with ``...``. Defaults to 4.
    """
    if len(embeds) > max_shown:
        # Drop the closing bracket of the sliced repr and append an ellipsis.
        embeds_trimmed = str(embeds[:max_shown])[:-1] + ", ...]"
    else:
        # Short enough to be shown in full.
        embeds_trimmed = embeds
    print(f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
def run_qwen3_vl():
    """Embed a text, an image, and an image+text prompt with Qwen/Qwen3-VL-Embedding-2B.

    Builds the three chat-formatted prompts by hand, creates the engine in
    pooling mode and prints a preview of each resulting embedding.
    """
    system_text = "Represent the user's input."
    vision_tokens = "<|vision_start|><|image_pad|><|vision_end|>"

    def chat_prompt(user_content):
        # System instruction, one user turn, then an open assistant turn.
        return (
            f"<|im_start|>system\n{system_text}<|im_end|>\n"
            f"<|im_start|>user\n{user_content}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    llm = LLM(
        **asdict(
            EngineArgs(
                model="Qwen/Qwen3-VL-Embedding-2B",
                runner="pooling",
                max_model_len=8192,
                limit_mm_per_prompt={"image": 1},
            )
        )
    )

    print("Text embedding output:")
    text_result = llm.embed(chat_prompt(text), use_tqdm=False)
    print_embeddings(text_result[0].outputs.embedding)

    # The image-only and image+text requests carry the same multimodal payload.
    multimodal_cases = (
        ("Image embedding output:", chat_prompt(vision_tokens)),
        ("Image+Text embedding output:", chat_prompt(f"{vision_tokens}{text}")),
    )
    for heading, prompt in multimodal_cases:
        print(heading)
        result = llm.embed(
            {"prompt": prompt, "multi_modal_data": multi_modal_data},
            use_tqdm=False,
        )
        print_embeddings(result[0].outputs.embedding)
# Maps each accepted --model value to the function that runs that example.
model_example_map = {
"qwen3_vl": run_qwen3_vl,
}
def parse_args(argv=None):
    """Parse the command-line options of this example.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``; ``model`` is one of the keys of
        ``model_example_map``.
    """
    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, which would put this sentence
    # into the usage line instead of the help body.
    parser = argparse.ArgumentParser(
        description="Script to run a specified VLM through vLLM offline api."
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=model_example_map.keys(),
        required=True,
        help="The name of the embedding model.",
    )
    return parser.parse_args(argv)
def main(args):
    """Run the example function registered for ``args.model``."""
    run_example = model_example_map[args.model]
    run_example()
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the chosen example.
    main(parse_args())
...@@ -21,7 +21,8 @@ from PIL import Image ...@@ -21,7 +21,8 @@ from PIL import Image
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/cat_snow.jpg"
text = "A cat standing in the snow."
def create_chat_embeddings( def create_chat_embeddings(
...@@ -30,6 +31,8 @@ def create_chat_embeddings( ...@@ -30,6 +31,8 @@ def create_chat_embeddings(
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
model: str, model: str,
encoding_format: Literal["base64", "float"] | NotGiven = NOT_GIVEN, encoding_format: Literal["base64", "float"] | NotGiven = NOT_GIVEN,
continue_final_message: bool = False,
add_special_tokens: bool = False,
) -> CreateEmbeddingResponse: ) -> CreateEmbeddingResponse:
""" """
Convenience function for accessing vLLM's Chat Embeddings API, Convenience function for accessing vLLM's Chat Embeddings API,
...@@ -38,10 +41,21 @@ def create_chat_embeddings( ...@@ -38,10 +41,21 @@ def create_chat_embeddings(
return client.post( return client.post(
"/embeddings", "/embeddings",
cast_to=CreateEmbeddingResponse, cast_to=CreateEmbeddingResponse,
body={"messages": messages, "model": model, "encoding_format": encoding_format}, body={
"messages": messages,
"model": model,
"encoding_format": encoding_format,
"continue_final_message": continue_final_message,
"add_special_tokens": add_special_tokens,
},
) )
def print_embeddings(embeds, max_shown: int = 4):
    """Print a short preview of an embedding together with its length.

    Args:
        embeds: The embedding to display; must support ``len()`` and slicing.
        max_shown: How many leading values to show before the rest is elided
            with ``...``. Defaults to 4.
    """
    if len(embeds) > max_shown:
        # Drop the closing bracket of the sliced repr and append an ellipsis.
        embeds_trimmed = str(embeds[:max_shown])[:-1] + ", ...]"
    else:
        # Short enough to be shown in full.
        embeds_trimmed = embeds
    print(f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
def run_clip(client: OpenAI, model: str): def run_clip(client: OpenAI, model: str):
""" """
Start the server using: Start the server using:
...@@ -145,6 +159,113 @@ def run_dse_qwen2_vl(client: OpenAI, model: str): ...@@ -145,6 +159,113 @@ def run_dse_qwen2_vl(client: OpenAI, model: str):
print("Text embedding output:", response.data[0].embedding) print("Text embedding output:", response.data[0].embedding)
def run_qwen3_vl(client: OpenAI, model: str):
    """
    Start the server using:

    vllm serve Qwen/Qwen3-VL-Embedding-2B \
        --runner pooling \
        --max-model-len 8192
    """
    default_instruction = "Represent the user's input."

    def text_part(value):
        # One text entry of a chat message's content list.
        return {"type": "text", "text": value}

    def image_part():
        # One image entry pointing at the module-level sample image URL.
        return {"type": "image_url", "image_url": {"url": image_url}}

    def embed_and_report(heading, user_parts):
        # Print the heading, send one request, then print the embedding preview.
        print(heading)
        conversation = [
            {"role": "system", "content": [text_part(default_instruction)]},
            {"role": "user", "content": user_parts},
            # Empty assistant turn; the request sets continue_final_message=True.
            {"role": "assistant", "content": [text_part("")]},
        ]
        reply = create_chat_embeddings(
            client,
            messages=conversation,
            model=model,
            encoding_format="float",
            continue_final_message=True,
            add_special_tokens=True,
        )
        print_embeddings(reply.data[0].embedding)

    embed_and_report("Text embedding output:", [text_part(text)])
    embed_and_report("Image embedding output:", [image_part(), text_part("")])
    embed_and_report(
        "Image+Text embedding output:",
        [image_part(), text_part(f"{text}")],
    )
def run_siglip(client: OpenAI, model: str): def run_siglip(client: OpenAI, model: str):
""" """
Start the server using: Start the server using:
...@@ -213,7 +334,8 @@ def run_vlm2vec(client: OpenAI, model: str): ...@@ -213,7 +334,8 @@ def run_vlm2vec(client: OpenAI, model: str):
encoding_format="float", encoding_format="float",
) )
print("Image embedding output:", response.data[0].embedding) print("Image embedding output:")
print_embeddings(response.data[0].embedding)
response = create_chat_embeddings( response = create_chat_embeddings(
client, client,
...@@ -233,7 +355,8 @@ def run_vlm2vec(client: OpenAI, model: str): ...@@ -233,7 +355,8 @@ def run_vlm2vec(client: OpenAI, model: str):
encoding_format="float", encoding_format="float",
) )
print("Image+Text embedding output:", response.data[0].embedding) print("Image+Text embedding output:")
print_embeddings(response.data[0].embedding)
response = create_chat_embeddings( response = create_chat_embeddings(
client, client,
...@@ -249,11 +372,13 @@ def run_vlm2vec(client: OpenAI, model: str): ...@@ -249,11 +372,13 @@ def run_vlm2vec(client: OpenAI, model: str):
encoding_format="float", encoding_format="float",
) )
print("Text embedding output:", response.data[0].embedding) print("Text embedding output:")
print_embeddings(response.data[0].embedding)
model_example_map = { model_example_map = {
"clip": run_clip, "clip": run_clip,
"qwen3_vl": run_qwen3_vl,
"dse_qwen2_vl": run_dse_qwen2_vl, "dse_qwen2_vl": run_dse_qwen2_vl,
"siglip": run_siglip, "siglip": run_siglip,
"vlm2vec": run_vlm2vec, "vlm2vec": run_vlm2vec,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example Python client for multimodal rerank API which is compatible with
Jina and Cohere https://jina.ai/reranker
Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
e.g.
vllm serve jinaai/jina-reranker-m0 --runner pooling
vllm serve Qwen/Qwen3-VL-Reranker-2B \
--runner pooling \
--max-model-len 4096 \
--hf_overrides '{"architectures": ["Qwen3VLForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' \
--chat-template examples/pooling/score/template/qwen3_vl_reranker.jinja
"""
import argparse
import json
import requests
# Request headers reused for both the model-listing GET and the rerank POST in main().
headers = {"accept": "application/json", "Content-Type": "application/json"}
# Text query sent as the "query" field of the rerank request.
query = "A woman playing with her dog on a beach at sunset."
# Sent as the "documents" field of the rerank request: a single "content"
# list holding one text part and one image_url part.
documents = {
"content": [
{
"type": "text",
"text": (
"A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, " # noqa: E501
"as the dog offers its paw in a heartwarming display of companionship and trust." # noqa: E501
),
},
{
"type": "image_url",
"image_url": {
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
},
},
]
}
def parse_args(argv=None):
    """Parse the server location options of this client.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so command-line use is unchanged.

    Returns:
        The parsed ``argparse.Namespace`` with ``host`` (str, default
        ``"localhost"``) and ``port`` (int, default ``8000``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    return parser.parse_args(argv)
def main(args):
    """Send one rerank request to the server at ``args.host:args.port`` and print the result.

    The model name is taken from the first entry returned by ``/v1/models``.
    """
    server = f"http://{args.host}:{args.port}"

    # Ask the server which model it exposes and use the first one listed.
    listing = requests.get(f"{server}/v1/models", headers=headers)
    served_model = listing.json()["data"][0]["id"]

    payload = {"model": served_model, "query": query, "documents": documents}
    result = requests.post(f"{server}/rerank", headers=headers, json=payload)

    # Report failures first, then fall through to the success output.
    if result.status_code != 200:
        print(f"Request failed with status code: {result.status_code}")
        print(result.text)
        return
    print("Request successful!")
    print(json.dumps(result.json(), indent=2))
if __name__ == "__main__":
    # Script entry point: parse the CLI options and issue the rerank request.
    main(parse_args())
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
""" """
Example online usage of Score API. Example online usage of Score API.
Run `vllm serve <model> --runner pooling` to start up the server in vLLM. Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
e.g.
vllm serve jinaai/jina-reranker-m0 --runner pooling
vllm serve Qwen/Qwen3-VL-Reranker-2B \
--runner pooling \
--max-model-len 4096 \
--hf_overrides '{"architectures": ["Qwen3VLForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' \
--chat-template examples/pooling/score/template/qwen3_vl_reranker.jinja
""" """
import argparse import argparse
import json
import pprint import pprint
import requests import requests
headers = {"accept": "application/json", "Content-Type": "application/json"}
def post_http_request(prompt: dict, api_url: str) -> requests.Response: text_1 = "slm markdown"
headers = {"User-Agent": "Test Client"} text_2 = {
response = requests.post(api_url, headers=headers, json=prompt) "content": [
return response {
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
},
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
},
},
]
}
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="jinaai/jina-reranker-m0")
return parser.parse_args() return parser.parse_args()
def main(args): def main(args):
api_url = f"http://{args.host}:{args.port}/score" base_url = f"http://{args.host}:{args.port}"
model_name = args.model models_url = base_url + "/v1/models"
score_url = base_url + "/score"
text_1 = "slm markdown"
text_2 = { response = requests.get(models_url, headers=headers)
"content": [ model = response.json()["data"][0]["id"]
{
"type": "image_url", prompt = {"model": model, "text_1": text_1, "text_2": text_2}
"image_url": { response = requests.post(score_url, headers=headers, json=prompt)
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
},
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
},
},
]
}
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url)
print("\nPrompt when text_1 is string and text_2 is a image list:") print("\nPrompt when text_1 is string and text_2 is a image list:")
pprint.pprint(prompt) pprint.pprint(prompt)
print("\nScore Response:") print("\nScore Response:")
pprint.pprint(score_response.json()) print(json.dumps(response.json(), indent=2))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment