"vllm/vscode:/vscode.git/clone" did not exist on "ab1091d5f2fc879ba9e62002f4d9eec013984d4d"
Unverified Commit 1cb39dbc authored by Christian Pinto's avatar Christian Pinto Committed by GitHub
Browse files

[Misc] IO Processor plugins for pooling models (#22820)


Signed-off-by: default avatarChristian Pinto <christian.pinto@ibm.com>
Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Co-authored-by: default avatarMax de Bayser <mbayser@br.ibm.com>
parent 437c3ce0
......@@ -4,7 +4,7 @@
import asyncio
import base64
import time
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Sequence
from typing import Final, Literal, Optional, Union, cast
import jinja2
......@@ -13,19 +13,25 @@ import torch
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (ErrorResponse,
IOProcessorRequest,
IOProcessorResponse,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
......@@ -52,7 +58,7 @@ class OpenAIServingPooling(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
vllm_config: VllmConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
......@@ -61,19 +67,21 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
model_config=vllm_config.model_config,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
io_processor_plugin = self.model_config.io_processor_plugin
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
async def create_pooling(
self,
request: PoolingRequest,
raw_request: Optional[Request] = None,
) -> Union[PoolingResponse, ErrorResponse]:
) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
"""
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
......@@ -82,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = self._get_model_name(request.model)
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = request.truncate_prompt_tokens
is_io_processor_request = isinstance(request, IOProcessorRequest)
try:
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
......@@ -104,7 +105,32 @@ class OpenAIServingPooling(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
if isinstance(request, PoolingChatRequest):
if getattr(request, "dimensions", None) is not None:
return self.create_error_response(
"dimensions is currently not supported")
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details.")
validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id)
request_prompts: Sequence[RequestPrompt] = [
""
] * len(engine_prompts)
elif isinstance(request, PoolingChatRequest):
(
_,
request_prompts,
......@@ -122,7 +148,7 @@ class OpenAIServingPooling(OpenAIServing):
continue_final_message=False,
add_special_tokens=request.add_special_tokens,
)
else:
elif isinstance(request, PoolingCompletionRequest):
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
......@@ -130,6 +156,9 @@ class OpenAIServingPooling(OpenAIServing):
request.input,
add_special_tokens=request.add_special_tokens,
)
else:
raise ValueError(
f"Unsupported request of type {type(request)}")
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
......@@ -171,6 +200,16 @@ class OpenAIServingPooling(OpenAIServing):
result_generator = merge_async_iterators(*generators)
if is_io_processor_request:
assert self.io_processor is not None
output = await self.io_processor.post_process_async(
model_output=result_generator,
request_id=request_id,
)
return self.io_processor.output_to_response(output)
assert isinstance(request,
(PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_prompts)
# Non-streaming response
......@@ -190,7 +229,7 @@ class OpenAIServingPooling(OpenAIServing):
request_id,
created_time,
model_name,
encoding_format,
request.encoding_format,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
......@@ -18,6 +18,7 @@ target model.
"""
__all__ = [
"DataPrompt",
"TextPrompt",
"TokensPrompt",
"PromptType",
......
......@@ -95,6 +95,16 @@ class EmbedsPrompt(TypedDict):
"""
class DataPrompt(TypedDict):
"""Represents generic inputs handled by IO processor plugins."""
data: Any
"""The input data"""
data_format: str
"""The input data format"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import logging
from typing import Optional
from vllm.config import VllmConfig
from vllm.plugins import load_plugins_by_group
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.utils import resolve_obj_by_qualname
logger = logging.getLogger(__name__)
def get_io_processor(
vllm_config: VllmConfig,
plugin_from_init: Optional[str] = None) -> IOProcessor | None:
# Input.Output processors are loaded as plugins under the
# 'vllm.io_processor_plugins' group. Similar to platform
# plugins, these plugins register a function that returns the class
# name for the processor to install.
if plugin_from_init:
model_plugin = plugin_from_init
else:
# A plugin can be specified via the model config
# Retrieve the model specific plugin if available
# This is using a custom field in the hf_config for the model
hf_config = vllm_config.model_config.hf_config.to_dict()
config_plugin = hf_config.get("io_processor_plugin")
model_plugin = config_plugin
if model_plugin is None:
logger.info("No IOProcessor plugins requested by the model")
return None
logger.debug("IOProcessor plugin to be loaded %s", model_plugin)
# Load all installed plugin in the group
multimodal_data_processor_plugins = \
load_plugins_by_group('vllm.io_processor_plugins')
loadable_plugins = {}
for name, func in multimodal_data_processor_plugins.items():
try:
assert callable(func)
processor_cls_qualname = func()
if processor_cls_qualname is not None:
loadable_plugins[name] = processor_cls_qualname
except Exception:
logger.warning("Failed to load plugin %s.", name, exc_info=True)
num_available_plugins = len(loadable_plugins.keys())
if num_available_plugins == 0:
raise ValueError("No IOProcessor plugins installed"
f" but one is required ({model_plugin}).")
if model_plugin not in loadable_plugins:
raise ValueError(
f"The model requires the '{model_plugin}' IO Processor plugin "
"but it is not installed. "
f"Available plugins: {list(loadable_plugins.keys())}")
activated_plugin_cls = loadable_plugins[model_plugin]
return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, Optional, TypeVar, Union
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
IOProcessorInput = TypeVar('IOProcessorInput')
IOProcessorOutput = TypeVar('IOProcessorOutput')
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
@abstractmethod
def pre_process(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
raise NotImplementedError
async def pre_process_async(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
return self.pre_process(prompt, request_id, **kwargs)
@abstractmethod
def post_process(self,
model_output: Sequence[PoolingRequestOutput],
request_id: Optional[str] = None,
**kwargs) -> IOProcessorOutput:
raise NotImplementedError
async def post_process_async(
self,
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
request_id: Optional[str] = None,
**kwargs,
) -> IOProcessorOutput:
collected_output = [item async for i, item in model_output]
return self.post_process(collected_output, request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
raise NotImplementedError
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment