"vscode:/vscode.git/clone" did not exist on "51c1b9f1474da067c22f6bdc07bb14e9ce06bb1e"
Unverified Commit 617d55c0 authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

chore: remove deprecated examples/multimodal directory (#8141)


Signed-off-by: default avatarRyan McCormick <rmccormick@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 326a702d
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Utility functions for processing chat messages."""
def extract_user_text(messages) -> str:
"""Extract and concatenate text content from user messages."""
user_texts = []
for message in messages:
if message.role == "user":
# Collect all text content items from this user message
text_parts = []
for item in message.content:
if item.type == "text" and item.text:
text_parts.append(item.text)
# If this user message has text content, join it and add to user_texts
if text_parts:
user_texts.append("".join(text_parts))
if not user_texts:
raise ValueError("No text content found in user messages")
# Join all user turns with newline separator
return "\n".join(user_texts)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import RequestResponseMetadata
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.inputs.data import TokensPrompt
from vllm.renderers.registry import renderer_from_config
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike as AnyTokenizer
class StubEngineClient:
"""
Stub EngineClient for preprocessing-only use of OpenAIServingChat/Completion.
Provides the minimal attributes required by OpenAIServingModels.
"""
def __init__(self, model_config: ModelConfig):
self.model_config = model_config
self.renderer = renderer_from_config(VllmConfig(model_config=model_config))
self.input_processor = None
self.io_processor = None
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
sampling_params = request.to_sampling_params(
default_max_tokens,
self.default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
# Create stub engine client and models for preprocessing-only usage
stub_engine = StubEngineClient(model_config)
serving_models = OpenAIServingModels(
engine_client=stub_engine,
base_model_paths=[
BaseModelPath(name=model_config.model, model_path=model_config.model)
],
)
serving_render = OpenAIServingRender(
model_config=model_config,
renderer=stub_engine.renderer,
io_processor=None,
model_registry=serving_models.registry,
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
self.openai_serving = OpenAIServingChat(
engine_client=stub_engine,
models=serving_models,
response_role="assistant",
openai_serving_render=serving_render,
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
if not request.chat_template and not self.tokenizer.chat_template:
chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}Assistant:"
else:
chat_template = request.chat_template or self.tokenizer.chat_template
(
conversation,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
request.messages,
default_template=chat_template,
default_template_content_format=self.openai_serving.chat_template_content_format,
default_template_kwargs=None,
tool_dicts=None,
tool_parser=None,
)
if not conversation or not engine_prompts:
raise ValueError(
"Preprocessing returned empty conversation or engine_prompts"
)
return PreprocessResult(conversation[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if request.stream:
# Handle streaming response
num_output_text_so_far = 0
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
yield raw_response
break
# Parse the response
response = json.loads(raw_response.lstrip("data: "))
# Process delta content to extract only new text
if "choices" in response and len(response["choices"]) > 0:
if "delta" in response["choices"][0]:
content = response["choices"][0]["delta"].get("content", "")
if content:
# Extract only the new part from the full content
new_content = content[num_output_text_so_far:]
response["choices"][0]["delta"]["content"] = new_content
num_output_text_so_far = len(content)
# Yield the processed response
yield f"data: {json.dumps(response)}\n\n"
else:
# Handle non-streaming response
# Collect all chunks into a single response
full_response = None
num_output_text_so_far = 0
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
if full_response is None:
# Initialize the full response structure
full_response = {
"id": response.get("id", ""),
"object": "chat.completion",
"created": int(time.time()),
"model": request.model,
"choices": [
{
"index": response.get("index", 0),
"message": {"role": "assistant", "content": ""},
"finish_reason": None,
}
],
}
# Concatenate content if it exists. Each delta contains the full text so far.
if "choices" in response and len(response["choices"]) > 0:
if "delta" in response["choices"][0]:
content = response["choices"][0]["delta"].get("content", "")
if content:
# Extract only the new part from the full content
new_content = content[num_output_text_so_far:]
full_response["choices"][0]["message"][
"content"
] += new_content
num_output_text_so_far = len(content)
# Update finish reason if present
if "finish_reason" in response["choices"][0]:
full_response["choices"][0]["finish_reason"] = response[
"choices"
][0]["finish_reason"]
if full_response is not None:
yield json.dumps(full_response)
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
# Create stub engine client and models for preprocessing-only usage
stub_engine = StubEngineClient(model_config)
serving_models = OpenAIServingModels(
engine_client=stub_engine,
base_model_paths=[
BaseModelPath(name=model_config.model, model_path=model_config.model)
],
)
serving_render = OpenAIServingRender(
model_config=model_config,
renderer=stub_engine.renderer,
io_processor=None,
model_registry=serving_models.registry,
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
self.openai_serving = OpenAIServingCompletion(
engine_client=stub_engine,
models=serving_models,
openai_serving_render=serving_render,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
engine_prompts = await self.openai_serving._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=getattr(request, "prompt_embeds", None),
)
if not engine_prompts:
raise ValueError("Preprocessing returned empty engine_prompts")
return PreprocessResult(None, engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
[], # engine_prompts (not needed for streaming output)
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Optional
import torch
from .model import SupportedModels
logger = logging.getLogger(__name__)
def get_qwen_image_features(
vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor:
"""
Extract image features using Qwen-style vision encoder.
Args:
vision_encoder: The vision encoder model
image_embeds: Dictionary containing pixel values and grid information
Returns:
Processed image features tensor
Raises:
ValueError: If grid_thw is not provided for Qwen model
"""
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
grid_thw = image_embeds.get("image_grid_thw", None)
if grid_thw is not None:
grid_thw = grid_thw.to(vision_encoder.device)
logger.debug(f"Qwen grid_thw shape: {grid_thw.shape}")
else:
raise ValueError("grid_thw is not provided")
return (
vision_encoder.get_image_features(pixel_values, grid_thw) # type: ignore
if grid_thw is not None
else vision_encoder.get_image_features(pixel_values) # type: ignore
)
def encode_image_embeddings(
model_name: str,
image_embeds: Dict[str, Any],
vision_encoder: torch.nn.Module,
projector: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
"""
Encode image embeddings using the appropriate model-specific encoder.
Args:
model_name: The model identifier
image_embeds: Dictionary containing processed image data
vision_encoder: The vision encoder module
projector: The multimodal projector (required for LLaVA-style models)
Returns:
Encoded embeddings tensor with normalized shape
Raises:
ValueError: If projector is missing for LLaVA models
NotImplementedError: If model is not supported
"""
with torch.no_grad():
# Route through the correct encoder based on model
if model_name == SupportedModels.LLAVA_1_5_7B:
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
vision_outputs = vision_encoder(pixel_values)
if projector is None:
raise ValueError(f"Projector not found for LLaVA model: {model_name}")
embeddings = projector(vision_outputs.last_hidden_state)
elif model_name == SupportedModels.QWEN_2_5_VL_7B:
embeddings = get_qwen_image_features(vision_encoder, image_embeds)
else:
raise NotImplementedError(f"Model not supported: {model_name}")
# Normalize output shape
if isinstance(embeddings, (tuple, list)):
embeddings = embeddings[0]
embeddings = embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings
return embeddings
def get_encoder_components(
model_name: str, vision_model: torch.nn.Module
) -> tuple[Any, Optional[Any]]:
"""
Get the appropriate vision encoder and projector components for a given model.
Args:
model_name: The model identifier
vision_model: The loaded vision model
Returns:
Tuple of (vision_encoder, projector) where types depend on the model
Raises:
NotImplementedError: If model is not supported
"""
if model_name == SupportedModels.LLAVA_1_5_7B:
vision_encoder = vision_model.vision_tower
projector = getattr(vision_model, "multi_modal_projector", None)
return vision_encoder, projector
elif model_name == SupportedModels.QWEN_2_5_VL_7B:
vision_encoder = vision_model
projector = None
return vision_encoder, projector
else:
raise NotImplementedError(f"Model not supported: {model_name}")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
import httpx
logger = logging.getLogger(__name__)
# Global HTTP client instance
_global_http_client: Optional[httpx.AsyncClient] = None
def get_http_client(timeout: float = 60.0) -> httpx.AsyncClient:
"""
Get or create a shared HTTP client instance.
Args:
timeout: Timeout for HTTP requests
Returns:
Shared HTTP client instance
"""
global _global_http_client
if _global_http_client is None or _global_http_client.is_closed:
_global_http_client = httpx.AsyncClient(
timeout=timeout,
follow_redirects=True,
limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
)
logger.info(f"Shared HTTP client initialized with timeout={timeout}s")
return _global_http_client
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import logging
from io import BytesIO
from urllib.parse import urlparse
import httpx
from PIL import Image
from .http_client import get_http_client
logger = logging.getLogger(__name__)
class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
def __init__(
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
):
self._http_timeout = http_timeout
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url)
# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if image_url_lower in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
return self._image_cache[image_url_lower]
try:
if parsed_url.scheme == "data":
# Parse data URL format: data:[<media type>][;base64],<data>
if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1)
if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded")
try:
image_bytes = base64.b64decode(data)
image_data = BytesIO(image_bytes)
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
http_client = get_http_client(self._http_timeout)
response = await http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
# PIL is sync, so offload to a thread to avoid blocking the event loop
# Restrict to supported formats to prevent PSD parsing (GHSA-cfh3-3jmp-rvhc)
image = await asyncio.to_thread(
Image.open, image_data, formats=["JPEG", "PNG", "WEBP"]
)
# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = await self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
await self._cache_queue.put(image_url_lower)
return image_converted
except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")
raise
except Exception as e:
logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional
import torch
from transformers import AutoModel
logger = logging.getLogger(__name__)
class SupportedModels:
"""Supported multimodal model identifiers"""
LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
QWEN_2_AUDIO_7B = "Qwen/Qwen2-Audio-7B-Instruct"
def load_vision_model(model_id: str) -> torch.nn.Module:
"""
Load a vision model from a HuggingFace model ID.
"""
model = AutoModel.from_pretrained(
model_id, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True
)
return model
def construct_mm_data(
model: str,
embeddings_dtype: torch.dtype,
image_embeds: Optional[torch.Tensor] = None,
video_numpy: Optional[Any] = None,
image_grid_thw: Optional[List[Any]] = None,
audio_embeds: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor | Dict[str, Any]]:
"""Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings"""
model_lower = model.lower()
if "audio" in model_lower:
audio_embeds = audio_embeds.to(torch.bfloat16)
assert audio_embeds.ndim == 2, "Audio embeddings must be 2D"
return {"audio": [audio_embeds]}
elif "video" in model_lower:
if video_numpy is None:
raise ValueError("No video frames provided.")
return {"video": video_numpy}
elif "qwen" in model_lower and "vl" in model_lower:
if image_embeds is None:
raise ValueError("No image embeddings provided.")
image_embeds = image_embeds.to(embeddings_dtype)
return _construct_qwen_image_data(image_embeds, image_grid_thw)
else:
# Default image handling for other models (e.g., LLAVA_1_5_7B)
if image_embeds is None:
raise ValueError("No image embeddings provided.")
image_embeds = image_embeds.to(embeddings_dtype)
return {"image": image_embeds}
def _construct_qwen_image_data(
image_embeds: torch.Tensor, image_grid_thw: Optional[List[Any]]
) -> Dict[str, Dict[str, torch.Tensor]]:
"""Construct image data specifically for Qwen models."""
if image_grid_thw is None or len(image_grid_thw) == 0:
raise ValueError("No image grid provided for Qwen model.")
grid_thw_tensor = torch.tensor(image_grid_thw)
return {
"image": {
"image_embeds": image_embeds.squeeze(0),
"image_grid_thw": grid_thw_tensor,
}
}
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
from vllm.logprobs import PromptLogprobs
from vllm.multimodal.inputs import MultiModalUUIDDict # noqa: F401
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import RequestStateStats
import dynamo.nixl_connect as connect
class Request(BaseModel):
prompt: str
sampling_params: dict
class Tokens(BaseModel):
tokens: list[int]
class PrefillRequest(Request):
request_id: str
class Response(BaseModel):
text: str
class PrefillResponse(BaseModel):
prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# Sampling params is a mspspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
prefix_hit_rate: Optional[float] = 0.0
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
@field_serializer("sampling_params")
def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]:
"""Serialize SamplingParams using msgspec and return as dict."""
return json.loads(msgspec.json.encode(value))
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
class TextContent(BaseModel):
type: Literal["text"]
text: str
class ImageURLDetail(BaseModel):
url: str
class ImageContent(BaseModel):
type: Literal["image_url"]
image_url: ImageURLDetail
class AudioURLDetail(BaseModel):
url: str
class AudioContent(BaseModel):
type: Literal["audio_url"]
audio_url: AudioURLDetail
class VideoURLDetail(BaseModel):
url: str
class VideoContent(BaseModel):
type: Literal["video_url"]
video_url: VideoURLDetail
MessageContent = Union[TextContent, ImageContent, AudioContent, VideoContent]
class ChatMessage(BaseModel):
role: Literal["user", "system", "assistant"]
content: List[MessageContent]
class MultiModalRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model: str
messages: List[ChatMessage]
max_tokens: Optional[int] = None
temperature: Optional[float] = None
stream: Optional[bool] = True
stream_options: Optional[dict] = None
class MultiModalInput(BaseModel):
image_url: Optional[str] = None
video_url: Optional[str] = None
audio_url: Optional[str] = None
class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)
# LoRA adapter name (matches the name used in load_lora)
model: Optional[str] = None
multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
image_grid_thw: Optional[List[Any]] = None
embeddings_shape: Optional[
Union[Tuple[int, int, int], Tuple[int, int, int, int], Tuple[int, int]]
] = None
serialized_request: Optional[connect.RdmaMetadata] = None
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestStateStats, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestStateStats] = None
kv_transfer_params: Optional[dict[str, Any]] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
......@@ -521,16 +521,6 @@ vllm_configs = {
)
],
),
# TODO: Enable this test case when we have 4 GPUs runners.
# "multimodal_disagg": VLLMConfig(
# name="multimodal_disagg",
# directory=os.path.join(WORKSPACE_DIR, "examples/multimodal"),
# script_name="disagg.sh",
# marks=[pytest.mark.gpu_4, pytest.mark.vllm],
# model="llava-hf/llava-1.5-7b-hf",
# delayed_start=45,
# script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
# ),
"completions_only": VLLMConfig(
name="completions_only",
directory=vllm_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment