Unverified commit 4570535e, authored by Cyrus Leung and committed by GitHub
Browse files

[Model] CLIP Embedding Support (#26010)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 2a6dc67e
...@@ -829,6 +829,7 @@ The following table lists those that are tested in vLLM. ...@@ -829,6 +829,7 @@ The following table lists those that are tested in vLLM.
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| |--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | ✅︎ |
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ | | `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ |
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ | | `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ |
| `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | \* | N/A | \* | \* | \* | | `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | \* | N/A | \* | \* | \* |
......
...@@ -58,6 +58,30 @@ class ModelRequestData(NamedTuple): ...@@ -58,6 +58,30 @@ class ModelRequestData(NamedTuple):
documents: Optional[ScoreMultiModalParam] = None documents: Optional[ScoreMultiModalParam] = None
def run_clip(query: Query) -> ModelRequestData:
    """Prepare the engine arguments and inputs for a CLIP embedding query.

    A text query is passed through as the prompt with no image attached.
    An image query is passed through with an empty prompt.

    Raises:
        ValueError: If ``query["modality"]`` is neither ``"text"`` nor
            ``"image"``.
    """
    if query["modality"] == "image":
        # For image input, make sure that the prompt text is empty
        image = query["image"]
        prompt = ""
    elif query["modality"] == "text":
        image = None
        prompt = query["text"]
    else:
        unsupported = query["modality"]
        raise ValueError(f"Unsupported query modality: '{unsupported}'")

    return ModelRequestData(
        engine_args=EngineArgs(
            model="openai/clip-vit-base-patch32",
            runner="pooling",
            limit_mm_per_prompt={"image": 1},
        ),
        prompt=prompt,
        image=image,
    )
def run_e5_v(query: Query) -> ModelRequestData: def run_e5_v(query: Query) -> ModelRequestData:
llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501 llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501
...@@ -146,7 +170,8 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData: ...@@ -146,7 +170,8 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData:
processor = AutoProcessor.from_pretrained( processor = AutoProcessor.from_pretrained(
model_id, model_id,
# `min_pixels` and `max_pixels` are deprecated # `min_pixels` and `max_pixels` are deprecated for
# transformers `preprocessor_config.json`
size={"shortest_edge": 3136, "longest_edge": 12845056}, size={"shortest_edge": 3136, "longest_edge": 12845056},
) )
processor.chat_template = load_chat_template( processor.chat_template = load_chat_template(
...@@ -172,8 +197,10 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData: ...@@ -172,8 +197,10 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData:
model=merged_path, model=merged_path,
runner="pooling", runner="pooling",
max_model_len=4096, max_model_len=4096,
trust_remote_code=True, mm_processor_kwargs={
mm_processor_kwargs={"num_crops": 4}, "min_pixels": 3136,
"max_pixels": 12845056,
},
limit_mm_per_prompt={"image": 1}, limit_mm_per_prompt={"image": 1},
) )
...@@ -299,6 +326,7 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]): ...@@ -299,6 +326,7 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]):
model_example_map = { model_example_map = {
"clip": run_clip,
"e5_v": run_e5_v, "e5_v": run_e5_v,
"vlm2vec_phi3v": run_vlm2vec_phi3v, "vlm2vec_phi3v": run_vlm2vec_phi3v,
"vlm2vec_qwen2vl": run_vlm2vec_qwen2vl, "vlm2vec_qwen2vl": run_vlm2vec_qwen2vl,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501 # ruff: noqa: E501
"""Example Python client for multimodal embedding API using vLLM API server """Example Python client for multimodal embedding API using vLLM API server.
NOTE:
start a supported multimodal embeddings model server with `vllm serve`, e.g. Refer to each `run_*` function for the command to run the server for that model.
vllm serve TIGER-Lab/VLM2Vec-Full \
--runner pooling \
--trust-remote-code \
--max-model-len 4096 \
--chat-template examples/template_vlm2vec_phi3v.jinja
""" """
import argparse import argparse
...@@ -47,7 +42,58 @@ def create_chat_embeddings( ...@@ -47,7 +42,58 @@ def create_chat_embeddings(
) )
def run_clip(client: OpenAI, model: str):
    """Request an image embedding and then a text embedding, printing both.

    Start the server using:

        vllm serve openai/clip-vit-base-patch32 --runner pooling
    """
    # Each request carries a single user message with exactly one content
    # part: the image first, then the text.
    requests = (
        (
            "Image embedding output:",
            [{"type": "image_url", "image_url": {"url": image_url}}],
        ),
        (
            "Text embedding output:",
            [{"type": "text", "text": "a photo of a cat"}],
        ),
    )

    for label, content in requests:
        response = create_chat_embeddings(
            client,
            messages=[{"role": "user", "content": content}],
            model=model,
            encoding_format="float",
        )
        print(label, response.data[0].embedding)
def run_vlm2vec(client: OpenAI, model: str): def run_vlm2vec(client: OpenAI, model: str):
"""
Start the server using:
vllm serve TIGER-Lab/VLM2Vec-Full \
--runner pooling \
--trust-remote-code \
--max-model-len 4096 \
--chat-template examples/template_vlm2vec_phi3v.jinja
"""
response = create_chat_embeddings( response = create_chat_embeddings(
client, client,
messages=[ messages=[
...@@ -103,6 +149,15 @@ def run_vlm2vec(client: OpenAI, model: str): ...@@ -103,6 +149,15 @@ def run_vlm2vec(client: OpenAI, model: str):
def run_dse_qwen2_vl(client: OpenAI, model: str): def run_dse_qwen2_vl(client: OpenAI, model: str):
"""
Start the server using:
vllm serve MrLight/dse-qwen2-2b-mrl-v1 \
--runner pooling \
--trust-remote-code \
--max-model-len 8192 \
--chat-template examples/template_dse_qwen2_vl.jinja
"""
response = create_chat_embeddings( response = create_chat_embeddings(
client, client,
messages=[ messages=[
...@@ -156,6 +211,7 @@ def run_dse_qwen2_vl(client: OpenAI, model: str): ...@@ -156,6 +211,7 @@ def run_dse_qwen2_vl(client: OpenAI, model: str):
model_example_map = { model_example_map = {
"clip": run_clip,
"vlm2vec": run_vlm2vec, "vlm2vec": run_vlm2vec,
"dse_qwen2_vl": run_dse_qwen2_vl, "dse_qwen2_vl": run_dse_qwen2_vl,
} }
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import CLIPModel
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ...utils import check_embeddings_close
# Text-only prompts used by the text embedding test.
HF_TEXT_PROMPTS = [
"a photo of a stop sign",
"a photo of a cherry blossom",
]
# Prompts paired with the image assets; the text is left empty for every
# image so that each request carries an image only.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "",
"cherry_blossom": "",
})
# Model checkpoints that every test in this module is parametrized over.
MODELS = ["openai/clip-vit-base-patch32"]
def _run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    input_texts: list[str],
    input_images: PromptImageInput,
    model: str,
    *,
    dtype: str,
) -> None:
    """Embed the same inputs with vLLM and with HF, then compare the results.

    For the HF reference, an input that contains ``pixel_values`` is embedded
    with ``get_image_features`` (after dropping ``input_ids``); any other
    input is embedded with ``get_text_features``.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        enforce_eager=True,
        max_model_len=77,
    ) as vllm_model:
        vllm_outputs = vllm_model.embed(input_texts, images=input_images)

    with hf_runner(model, dtype=dtype, auto_cls=CLIPModel) as hf_model:
        hf_outputs = []
        for hf_inputs in hf_model.get_inputs(input_texts, images=input_images):
            if "pixel_values" in hf_inputs:
                # Image request: the text token ids must not be forwarded.
                hf_inputs.pop("input_ids")
                extract_features = hf_model.model.get_image_features
            else:
                extract_features = hf_model.model.get_text_features

            features = extract_features(**hf_model.wrap_device(hf_inputs))
            hf_outputs.append(features.squeeze(0).tolist())

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    """Run the vLLM-vs-HF embedding comparison on text-only prompts."""
    input_texts = list(HF_TEXT_PROMPTS)
    # Text-only requests: one ``None`` image placeholder per prompt.
    input_images = [None] * len(input_texts)

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,  # type: ignore
        model,
        dtype=dtype,
    )
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    """Run the vLLM-vs-HF embedding comparison on image prompts."""
    # Pair each image prompt with its asset; zip stops at the shorter input.
    prompt_asset_pairs = list(zip(HF_IMAGE_PROMPTS, image_assets))
    input_texts = [prompt for prompt, _ in prompt_asset_pairs]
    input_images = [asset.pil_image for _, asset in prompt_asset_pairs]

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,
        model,
        dtype=dtype,
    )
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text_image_no_crash(
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    """Check that a combined text+image request is rejected without
    breaking the requests that follow it."""
    prompt_batch = [HF_TEXT_PROMPTS[0]]
    image_batch = [image_assets[0].pil_image]

    runner_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        enforce_eager=True,
        max_model_len=77,
    )
    with vllm_runner(model, **runner_kwargs) as vllm_model:
        # Non-empty text together with an image is expected to be refused.
        with pytest.raises(ValueError, match="not both"):
            vllm_model.embed(prompt_batch, images=image_batch)

        # Should still be able to run subsequent requests
        vllm_model.embed(prompt_batch)
        vllm_model.embed([""], images=image_batch)
...@@ -389,6 +389,7 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -389,6 +389,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501
# [Multimodal] # [Multimodal]
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True), trust_remote_code=True),
...@@ -687,7 +688,11 @@ class HfExampleModels: ...@@ -687,7 +688,11 @@ class HfExampleModels:
return self.hf_models.keys() return self.hf_models.keys()
def get_hf_info(self, model_arch: str) -> _HfExamplesInfo: def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
try:
return self.hf_models[model_arch] return self.hf_models[model_arch]
except KeyError:
raise ValueError(f"No example model defined for {model_arch}; "
f"please update this file.") from None
def find_hf_info(self, model_id: str) -> _HfExamplesInfo: def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
for info in self.hf_models.values(): for info in self.hf_models.values():
...@@ -699,7 +704,8 @@ class HfExampleModels: ...@@ -699,7 +704,8 @@ class HfExampleModels:
if any(extra == model_id for extra in info.extras.values()): if any(extra == model_id for extra in info.extras.values()):
return info return info
raise ValueError(f"No example model defined for {model_id}") raise ValueError(f"No example model defined for {model_id}; "
f"please update this file.")
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
......
...@@ -417,12 +417,16 @@ class MultiHeadAttention(nn.Module): ...@@ -417,12 +417,16 @@ class MultiHeadAttention(nn.Module):
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: Optional[int] = None, num_kv_heads: Optional[int] = None,
): # This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = scale self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, \ assert self.num_heads % self.num_kv_heads == 0, \
f"num_heads ({self.num_heads}) is not " \ f"num_heads ({self.num_heads}) is not " \
......
...@@ -351,7 +351,7 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -351,7 +351,7 @@ class BertModel(nn.Module, SupportsQuant):
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids) return self.embeddings.word_embeddings(input_ids)
def forward( def forward(
self, self,
......
This diff is collapsed.
...@@ -187,6 +187,7 @@ _EMBEDDING_MODELS = { ...@@ -187,6 +187,7 @@ _EMBEDDING_MODELS = {
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
# [Multimodal] # [Multimodal]
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
......
...@@ -92,8 +92,10 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: ...@@ -92,8 +92,10 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
return current_platform.get_vit_attn_backend(head_size, dtype) return current_platform.get_vit_attn_backend(head_size, dtype)
VisionFeatureSelectStrategyStr = Literal["class", "default", "full"]
VisionFeatureSelectStrategy = Union[ VisionFeatureSelectStrategy = Union[
Literal["class", "default", "full"], VisionFeatureSelectStrategyStr,
Callable[[torch.Tensor], torch.Tensor], Callable[[torch.Tensor], torch.Tensor],
] ]
...@@ -106,7 +108,7 @@ def _get_vision_feature_selector( ...@@ -106,7 +108,7 @@ def _get_vision_feature_selector(
# https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762 # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
if strategy == "class": if strategy == "class":
return lambda feats: feats[:, 0, :] return lambda feats: feats[:, :1, :]
# https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196 # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
if strategy == "default": if strategy == "default":
......
...@@ -33,6 +33,7 @@ def _get_minicpmv_chat_template_fallback( ...@@ -33,6 +33,7 @@ def _get_minicpmv_chat_template_fallback(
# yapf: disable # yapf: disable
_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
"blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja",
"clip": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
"fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment