"vscode:/vscode.git/clone" did not exist on "c66c7f86aca956014d9ec6cc7a3e6001037e4655"
Unverified Commit 58fab50d authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[Frontend] Require flag for loading text and image embeds (#27204)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent db6f28d8
...@@ -359,13 +359,19 @@ Full example: [examples/offline_inference/audio_language.py](../../examples/offl ...@@ -359,13 +359,19 @@ Full example: [examples/offline_inference/audio_language.py](../../examples/offl
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary. pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
You must enable this feature via `enable_mm_embeds=True`.
!!! warning
The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!
??? code ??? code
```python ```python
from vllm import LLM from vllm import LLM
# Inference with image embeddings as input # Inference with image embeddings as input
llm = LLM(model="llava-hf/llava-1.5-7b-hf") llm = LLM(model="llava-hf/llava-1.5-7b-hf", enable_mm_embeds=True)
# Refer to the HuggingFace repo for the correct format to use # Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:" prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
...@@ -397,7 +403,11 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd ...@@ -397,7 +403,11 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
image_embeds = torch.load(...) image_embeds = torch.load(...)
# Qwen2-VL # Qwen2-VL
llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 4}) llm = LLM(
"Qwen/Qwen2-VL-2B-Instruct",
limit_mm_per_prompt={"image": 4},
enable_mm_embeds=True,
)
mm_data = { mm_data = {
"image": { "image": {
"image_embeds": image_embeds, "image_embeds": image_embeds,
...@@ -407,7 +417,12 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd ...@@ -407,7 +417,12 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
} }
# MiniCPM-V # MiniCPM-V
llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={"image": 4}) llm = LLM(
"openbmb/MiniCPM-V-2_6",
trust_remote_code=True,
limit_mm_per_prompt={"image": 4},
enable_mm_embeds=True,
)
mm_data = { mm_data = {
"image": { "image": {
"image_embeds": image_embeds, "image_embeds": image_embeds,
...@@ -732,7 +747,13 @@ Full example: [examples/online_serving/openai_chat_completion_client_for_multimo ...@@ -732,7 +747,13 @@ Full example: [examples/online_serving/openai_chat_completion_client_for_multimo
### Embedding Inputs ### Embedding Inputs
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape to the corresponding field of the multi-modal dictionary. pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
You must enable this feature via the `--enable-mm-embeds` flag in `vllm serve`.
!!! warning
The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!
#### Image Embedding Inputs #### Image Embedding Inputs
......
...@@ -20,12 +20,16 @@ You can pass prompt embeddings from Hugging Face Transformers models to the `'p ...@@ -20,12 +20,16 @@ You can pass prompt embeddings from Hugging Face Transformers models to the `'p
## Online Serving ## Online Serving
Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package. Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package and are enabled by the `--enable-prompt-embeds` flag in `vllm serve`.
When a mixture of `'prompt_embeds'` and `'prompt'` inputs are provided in a single request, the prompt embeds are always returned first. When a mixture of `'prompt_embeds'` and `'prompt'` inputs are provided in a single request, the prompt embeds are always returned first.
Prompt embeddings are passed in as base64 encoded torch tensors. Prompt embeddings are passed in as base64 encoded torch tensors.
!!! warning
The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!
### Transformers Inputs via OpenAI Client ### Transformers Inputs via OpenAI Client
First, launch the OpenAI-compatible server: First, launch the OpenAI-compatible server:
......
...@@ -49,6 +49,7 @@ class PrithviMAE: ...@@ -49,6 +49,7 @@ class PrithviMAE:
dtype="float16", dtype="float16",
enforce_eager=True, enforce_eager=True,
model_impl="terratorch", model_impl="terratorch",
enable_mm_embeds=True,
) )
def run(self, input_data, location_coords): def run(self, input_data, location_coords):
......
...@@ -38,6 +38,7 @@ def main(): ...@@ -38,6 +38,7 @@ def main():
max_num_seqs=32, max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff", io_processor_plugin="prithvi_to_tiff",
model_impl="terratorch", model_impl="terratorch",
enable_mm_embeds=True,
) )
pooling_params = PoolingParams(task="token_classify", activation=False) pooling_params = PoolingParams(task="token_classify", activation=False)
......
...@@ -19,6 +19,7 @@ import requests ...@@ -19,6 +19,7 @@ import requests
# --task embed --trust-remote-code # --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager # --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff # --io-processor-plugin prithvi_to_tiff
# --enable-mm-embeds
def main(): def main():
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import torch
from vllm import LLM from vllm import LLM
...@@ -12,8 +13,22 @@ def test_empty_prompt(): ...@@ -12,8 +13,22 @@ def test_empty_prompt():
llm.generate([""]) llm.generate([""])
@pytest.mark.skip_v1
def test_out_of_vocab_token(): def test_out_of_vocab_token():
llm = LLM(model="openai-community/gpt2", enforce_eager=True) llm = LLM(model="openai-community/gpt2", enforce_eager=True)
with pytest.raises(ValueError, match="out of vocabulary"): with pytest.raises(ValueError, match="out of vocabulary"):
llm.generate({"prompt_token_ids": [999999]}) llm.generate({"prompt_token_ids": [999999]})
def test_require_mm_embeds():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
enforce_eager=True,
enable_mm_embeds=False,
)
with pytest.raises(ValueError, match="--enable-mm-embeds"):
llm.generate(
{
"prompt": "<image>",
"multi_modal_data": {"image": torch.empty(1, 1, 1)},
}
)
...@@ -292,3 +292,16 @@ async def test_prompt_logprobs_raises_error( ...@@ -292,3 +292,16 @@ async def test_prompt_logprobs_raises_error(
temperature=0.0, temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True}, extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True},
) )
@pytest.mark.asyncio
async def test_empty_prompt_embeds(
client_with_prompt_embeds: openai.AsyncOpenAI,
) -> None:
await client_with_prompt_embeds.completions.create(
model=MODEL_NAME,
prompt="Hello",
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": []},
)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io import io
from unittest.mock import Mock
# imports for structured outputs tests # imports for structured outputs tests
import openai import openai
...@@ -10,7 +11,8 @@ import pytest ...@@ -10,7 +11,8 @@ import pytest
import regex as re import regex as re
import torch import torch
from vllm.entrypoints.renderer import BaseRenderer from vllm.config import ModelConfig
from vllm.entrypoints.renderer import CompletionRenderer
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
...@@ -59,6 +61,10 @@ async def test_out_of_vocab_token_ids(): ...@@ -59,6 +61,10 @@ async def test_out_of_vocab_token_ids():
def test_load_prompt_embeds( def test_load_prompt_embeds(
dtype: torch.dtype, layout: torch.layout, seq_len: int, hidden_size: int dtype: torch.dtype, layout: torch.layout, seq_len: int, hidden_size: int
): ):
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = True
renderer = CompletionRenderer(model_config, tokenizer=None)
# construct arbitrary tensors of various dtypes, layouts, and sizes. # construct arbitrary tensors of various dtypes, layouts, and sizes.
# We need to check against different layouts to make sure that if a user # We need to check against different layouts to make sure that if a user
# uses sparse tensors to reduce the transmission size of prompt embeddings, # uses sparse tensors to reduce the transmission size of prompt embeddings,
...@@ -83,7 +89,7 @@ def test_load_prompt_embeds( ...@@ -83,7 +89,7 @@ def test_load_prompt_embeds(
buffer.seek(0) buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue()) encoded_tensor = pybase64.b64encode(buffer.getvalue())
loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor) loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor)
assert len(loaded_prompt_embeds) == 1 assert len(loaded_prompt_embeds) == 1
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
assert loaded_tensor.device.type == "cpu" assert loaded_tensor.device.type == "cpu"
...@@ -91,3 +97,22 @@ def test_load_prompt_embeds( ...@@ -91,3 +97,22 @@ def test_load_prompt_embeds(
torch.testing.assert_close( torch.testing.assert_close(
loaded_tensor, tensor.to("cpu").to_dense(), equal_nan=True loaded_tensor, tensor.to("cpu").to_dense(), equal_nan=True
) )
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("seq_len", [2])
@pytest.mark.parametrize("hidden_size", [2])
def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int):
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = False
renderer = CompletionRenderer(model_config, tokenizer=None)
tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())
with pytest.raises(ValueError, match="--enable-prompt-embeds"):
renderer.load_prompt_embeds(encoded_tensor)
...@@ -15,30 +15,7 @@ MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" ...@@ -15,30 +15,7 @@ MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16" DTYPE = "float16"
@pytest.fixture(scope="module") def _terratorch_dummy_inputs(model_name: str):
def server():
args = [
"--runner",
"pooling",
# use half precision for speed and memory savings in CI environment
"--dtype",
DTYPE,
"--enforce-eager",
"--trust-remote-code",
"--skip-tokenizer-init",
"--max-num-seqs",
"32",
"--model-impl",
"terratorch",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_request(server: RemoteOpenAIServer, model_name: str):
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
...@@ -54,7 +31,7 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str): ...@@ -54,7 +31,7 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str):
binary_data = buffer_coord.read() binary_data = buffer_coord.read()
base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8") base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8")
prompt = { return {
"model": model_name, "model": model_name,
"additional_data": {"prompt_token_ids": [1]}, "additional_data": {"prompt_token_ids": [1]},
"encoding_format": "base64", "encoding_format": "base64",
...@@ -74,6 +51,28 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str): ...@@ -74,6 +51,28 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str):
], ],
} }
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_request(model_name: str):
args = [
"--runner",
"pooling",
# use half precision for speed and memory savings in CI environment
"--dtype",
DTYPE,
"--enforce-eager",
"--trust-remote-code",
"--max-num-seqs",
"32",
"--model-impl",
"terratorch",
"--skip-tokenizer-init",
"--enable-mm-embeds",
]
with RemoteOpenAIServer(MODEL_NAME, args) as server:
prompt = _terratorch_dummy_inputs(model_name)
# test single pooling # test single pooling
response = requests.post(server.url_for("pooling"), json=prompt) response = requests.post(server.url_for("pooling"), json=prompt)
response.raise_for_status() response.raise_for_status()
...@@ -81,5 +80,4 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str): ...@@ -81,5 +80,4 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str):
output = response.json()["data"][0]["data"] output = response.json()["data"][0]["data"]
np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32) np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32)
assert len(np_response) == 524288 assert len(np_response) == 524288
...@@ -73,6 +73,19 @@ def phi3v_model_config_mm_interleaved(): ...@@ -73,6 +73,19 @@ def phi3v_model_config_mm_interleaved():
) )
@pytest.fixture(scope="function")
def phi3v_model_config_image_embeds():
return ModelConfig(
PHI3V_MODEL_ID,
runner="generate",
trust_remote_code=True,
limit_mm_per_prompt={
"image": 2,
},
enable_mm_embeds=True,
)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def phi3v_tokenizer(): def phi3v_tokenizer():
return get_tokenizer(PHI3V_MODEL_ID) return get_tokenizer(PHI3V_MODEL_ID)
...@@ -799,7 +812,7 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( ...@@ -799,7 +812,7 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
def test_parse_chat_messages_empty_image_embeds_with_uuid( def test_parse_chat_messages_empty_image_embeds_with_uuid(
phi3v_model_config, phi3v_model_config_image_embeds,
phi3v_tokenizer, phi3v_tokenizer,
): ):
uuid = "abcd" uuid = "abcd"
...@@ -813,7 +826,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( ...@@ -813,7 +826,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
], ],
} }
], ],
phi3v_model_config, phi3v_model_config_image_embeds,
phi3v_tokenizer, phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -832,7 +845,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( ...@@ -832,7 +845,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
phi3v_model_config, phi3v_model_config_image_embeds,
phi3v_tokenizer, phi3v_tokenizer,
): ):
uuid = "abcd" uuid = "abcd"
...@@ -846,7 +859,7 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( ...@@ -846,7 +859,7 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
], ],
} }
], ],
phi3v_model_config, phi3v_model_config_image_embeds,
phi3v_tokenizer, phi3v_tokenizer,
content_format="string", content_format="string",
) )
......
...@@ -17,6 +17,7 @@ from vllm.inputs.data import is_embeds_prompt ...@@ -17,6 +17,7 @@ from vllm.inputs.data import is_embeds_prompt
class MockModelConfig: class MockModelConfig:
max_model_len: int = 100 max_model_len: int = 100
encoder_config: dict | None = None encoder_config: dict | None = None
enable_prompt_embeds: bool = True
class MockTokenizerResult: class MockTokenizerResult:
......
...@@ -109,8 +109,7 @@ VLM_TEST_SETTINGS = { ...@@ -109,8 +109,7 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt={"image": 4}, limit_mm_per_prompt={"image": 4},
) )
], ],
# TODO: Revert to "auto" when CPU backend can use torch > 2.6 vllm_runner_kwargs={"enable_mm_embeds": True},
dtype="bfloat16" if current_platform.is_cpu() else "auto",
marks=[pytest.mark.core_model, pytest.mark.cpu_model], marks=[pytest.mark.core_model, pytest.mark.cpu_model],
), ),
"qwen2_5_vl": VLMTestInfo( "qwen2_5_vl": VLMTestInfo(
......
...@@ -292,6 +292,7 @@ def run_embedding_input_test( ...@@ -292,6 +292,7 @@ def run_embedding_input_test(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
default_torch_num_threads=1, default_torch_num_threads=1,
enable_mm_embeds=True,
) as vllm_model: ) as vllm_model:
outputs_per_case_for_original_input = [ outputs_per_case_for_original_input = [
vllm_model.generate_greedy_logprobs( vllm_model.generate_greedy_logprobs(
......
...@@ -34,6 +34,7 @@ def _run_test( ...@@ -34,6 +34,7 @@ def _run_test(
dtype="half", dtype="half",
enforce_eager=True, enforce_eager=True,
skip_tokenizer_init=True, skip_tokenizer_init=True,
enable_mm_embeds=True,
# Limit the maximum number of sequences to avoid the # Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run # test going OOM during the warmup run
max_num_seqs=32, max_num_seqs=32,
......
...@@ -104,6 +104,11 @@ def can_initialize( ...@@ -104,6 +104,11 @@ def can_initialize(
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
if model_arch == "WhisperForConditionalGeneration": if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
extra_args = {}
if model_arch in ("PrithviGeoSpatialMAE", "Terratorch"):
extra_args["enable_mm_embeds"] = True
LLM( LLM(
model_info.default, model_info.default,
tokenizer=model_info.tokenizer, tokenizer=model_info.tokenizer,
...@@ -128,6 +133,7 @@ def can_initialize( ...@@ -128,6 +133,7 @@ def can_initialize(
else "vllm", else "vllm",
hf_overrides=hf_overrides_fn, hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs, max_num_seqs=model_info.max_num_seqs,
**extra_args,
) )
......
...@@ -32,6 +32,7 @@ def test_inference( ...@@ -32,6 +32,7 @@ def test_inference(
dtype="half", dtype="half",
enforce_eager=True, enforce_eager=True,
skip_tokenizer_init=True, skip_tokenizer_init=True,
enable_mm_embeds=True,
# Limit the maximum number of sequences to avoid the # Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run # test going OOM during the warmup run
max_num_seqs=32, max_num_seqs=32,
......
...@@ -38,6 +38,7 @@ def server(): ...@@ -38,6 +38,7 @@ def server():
"prithvi_to_tiff", "prithvi_to_tiff",
"--model-impl", "--model-impl",
"terratorch", "terratorch",
"--enable-mm-embeds",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......
...@@ -6,7 +6,6 @@ import openai # use the official client for correctness check ...@@ -6,7 +6,6 @@ import openai # use the official client for correctness check
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import regex as re import regex as re
import requests
from openai import BadRequestError from openai import BadRequestError
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
...@@ -686,17 +685,3 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): ...@@ -686,17 +685,3 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
"structured_outputs": {"grammar": invalid_simplified_sql_grammar} "structured_outputs": {"grammar": invalid_simplified_sql_grammar}
}, },
) )
@pytest.mark.asyncio
async def test_completion_with_empty_prompt_embeds(client: openai.AsyncOpenAI) -> None:
"""Test completion with empty prompt embeds."""
payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []}
headers: dict[str, str] = {"Content-Type": "application/json"}
# base_url = http://localhost:8000/v1/completions
response = requests.post(
f"{client.base_url}completions", headers=headers, json=payload
)
assert response.status_code == 200, (
f"Expected status code 200, got {response.status_code}. "
)
...@@ -32,6 +32,7 @@ def default_image_embeds_server_args() -> list[str]: ...@@ -32,6 +32,7 @@ def default_image_embeds_server_args() -> list[str]:
"--enforce-eager", "--enforce-eager",
"--limit-mm-per-prompt", "--limit-mm-per-prompt",
json.dumps({"image": MAXIMUM_IMAGES}), json.dumps({"image": MAXIMUM_IMAGES}),
"--enable-mm-embeds",
] ]
......
...@@ -232,8 +232,10 @@ class ModelConfig: ...@@ -232,8 +232,10 @@ class ModelConfig:
output will contain token ids.""" output will contain token ids."""
enable_prompt_embeds: bool = False enable_prompt_embeds: bool = False
"""If `True`, enables passing text embeddings as inputs via the """If `True`, enables passing text embeddings as inputs via the
`prompt_embeds` key. Note that enabling this will double the time required `prompt_embeds` key.
for graph compilation."""
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!"""
served_model_name: str | list[str] | None = None served_model_name: str | list[str] | None = None
"""The model name(s) used in the API. If multiple names are provided, the """The model name(s) used in the API. If multiple names are provided, the
server will respond to any of the provided names. The model name in the server will respond to any of the provided names. The model name in the
...@@ -303,6 +305,7 @@ class ModelConfig: ...@@ -303,6 +305,7 @@ class ModelConfig:
"""Configuration for multimodal model. If `None`, this will be inferred """Configuration for multimodal model. If `None`, this will be inferred
from the architecture of `self.model`.""" from the architecture of `self.model`."""
limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None
enable_mm_embeds: InitVar[bool | None] = None
media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None
mm_processor_kwargs: InitVar[dict[str, Any] | None] = None mm_processor_kwargs: InitVar[dict[str, Any] | None] = None
mm_processor_cache_gb: InitVar[float | None] = None mm_processor_cache_gb: InitVar[float | None] = None
...@@ -421,6 +424,7 @@ class ModelConfig: ...@@ -421,6 +424,7 @@ class ModelConfig:
self, self,
# Multimodal config init vars # Multimodal config init vars
limit_mm_per_prompt: dict[str, int] | None, limit_mm_per_prompt: dict[str, int] | None,
enable_mm_embeds: bool | None,
media_io_kwargs: dict[str, dict[str, Any]] | None, media_io_kwargs: dict[str, dict[str, Any]] | None,
mm_processor_kwargs: dict[str, Any] | None, mm_processor_kwargs: dict[str, Any] | None,
mm_processor_cache_gb: float | None, mm_processor_cache_gb: float | None,
...@@ -731,6 +735,7 @@ class ModelConfig: ...@@ -731,6 +735,7 @@ class ModelConfig:
mm_config_kwargs = dict( mm_config_kwargs = dict(
limit_per_prompt=limit_mm_per_prompt, limit_per_prompt=limit_mm_per_prompt,
enable_mm_embeds=enable_mm_embeds,
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
mm_processor_cache_gb=mm_processor_cache_gb, mm_processor_cache_gb=mm_processor_cache_gb,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment