"benchmarks/vscode:/vscode.git/clone" did not exist on "a776a48b1c753645c547b735ab647867c98a9b0c"
Unverified commit 1cb39dbc, authored by Christian Pinto and committed by GitHub
Browse files

[Misc] IO Processor plugins for pooling models (#22820)


Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
parent 437c3ce0
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import asyncio import asyncio
import base64 import base64
import time import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Sequence
from typing import Final, Literal, Optional, Union, cast from typing import Final, Literal, Optional, Union, cast
import jinja2 import jinja2
...@@ -13,19 +13,25 @@ import torch ...@@ -13,19 +13,25 @@ import torch
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (ErrorResponse, from vllm.entrypoints.openai.protocol import (ErrorResponse,
IOProcessorRequest,
IOProcessorResponse,
PoolingChatRequest, PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse, PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo) PoolingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing # yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.utils import merge_async_iterators from vllm.utils import merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -52,7 +58,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -52,7 +58,7 @@ class OpenAIServingPooling(OpenAIServing):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig, vllm_config: VllmConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
...@@ -61,19 +67,21 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -61,19 +67,21 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack: bool = False, log_error_stack: bool = False,
) -> None: ) -> None:
super().__init__(engine_client=engine_client, super().__init__(engine_client=engine_client,
model_config=model_config, model_config=vllm_config.model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack) log_error_stack=log_error_stack)
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
io_processor_plugin = self.model_config.io_processor_plugin
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
async def create_pooling( async def create_pooling(
self, self,
request: PoolingRequest, request: PoolingRequest,
raw_request: Optional[Request] = None, raw_request: Optional[Request] = None,
) -> Union[PoolingResponse, ErrorResponse]: ) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
""" """
See https://platform.openai.com/docs/api-reference/embeddings/create See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API. for the API specification. This API mimics the OpenAI Embedding API.
...@@ -82,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -82,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = self._get_model_name(request.model) model_name = self._get_model_name(request.model)
request_id = f"pool-{self._base_request_id(raw_request)}" request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time()) created_time = int(time.time())
truncate_prompt_tokens = request.truncate_prompt_tokens is_io_processor_request = isinstance(request, IOProcessorRequest)
try: try:
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
...@@ -104,7 +105,32 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -104,7 +105,32 @@ class OpenAIServingPooling(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request tokenizer = await self.engine_client.get_tokenizer(lora_request
) )
if isinstance(request, PoolingChatRequest): if getattr(request, "dimensions", None) is not None:
return self.create_error_response(
"dimensions is currently not supported")
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details.")
validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id)
request_prompts: Sequence[RequestPrompt] = [
""
] * len(engine_prompts)
elif isinstance(request, PoolingChatRequest):
( (
_, _,
request_prompts, request_prompts,
...@@ -122,7 +148,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -122,7 +148,7 @@ class OpenAIServingPooling(OpenAIServing):
continue_final_message=False, continue_final_message=False,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
else: elif isinstance(request, PoolingCompletionRequest):
(request_prompts, (request_prompts,
engine_prompts) = await self._preprocess_completion( engine_prompts) = await self._preprocess_completion(
request, request,
...@@ -130,6 +156,9 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -130,6 +156,9 @@ class OpenAIServingPooling(OpenAIServing):
request.input, request.input,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
else:
raise ValueError(
f"Unsupported request of type {type(request)}")
except (ValueError, TypeError, jinja2.TemplateError) as e: except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -171,6 +200,16 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -171,6 +200,16 @@ class OpenAIServingPooling(OpenAIServing):
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
if is_io_processor_request:
assert self.io_processor is not None
output = await self.io_processor.post_process_async(
model_output=result_generator,
request_id=request_id,
)
return self.io_processor.output_to_response(output)
assert isinstance(request,
(PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_prompts) num_prompts = len(engine_prompts)
# Non-streaming response # Non-streaming response
...@@ -190,7 +229,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -190,7 +229,7 @@ class OpenAIServingPooling(OpenAIServing):
request_id, request_id,
created_time, created_time,
model_name, model_name,
encoding_format, request.encoding_format,
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs, ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
...@@ -18,6 +18,7 @@ target model. ...@@ -18,6 +18,7 @@ target model.
""" """
__all__ = [ __all__ = [
"DataPrompt",
"TextPrompt", "TextPrompt",
"TokensPrompt", "TokensPrompt",
"PromptType", "PromptType",
......
...@@ -95,6 +95,16 @@ class EmbedsPrompt(TypedDict): ...@@ -95,6 +95,16 @@ class EmbedsPrompt(TypedDict):
""" """
class DataPrompt(TypedDict):
    """Schema for a generic prompt consumed by IO processor plugins.

    Both keys are required. The payload is typed as ``Any``, so this schema
    places no constraint on its structure.
    """

    data: Any
    """The input payload."""

    data_format: str
    """Name of the format of ``data``."""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
""" """
Set of possible schemas for a single prompt: Set of possible schemas for a single prompt:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import logging
from typing import Optional
from vllm.config import VllmConfig
from vllm.plugins import load_plugins_by_group
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.utils import resolve_obj_by_qualname
logger = logging.getLogger(__name__)
def get_io_processor(
        vllm_config: VllmConfig,
        plugin_from_init: Optional[str] = None) -> IOProcessor | None:
    """Instantiate the IO processor plugin requested for the model, if any.

    The plugin name is ``plugin_from_init`` when that argument is truthy;
    otherwise it is read from the ``io_processor_plugin`` key of the model's
    HF config dictionary.

    Args:
        vllm_config: Engine configuration. It is used to look up the HF
            config and is passed to the processor's constructor.
        plugin_from_init: Plugin name supplied explicitly by the caller.

    Returns:
        An instance of the selected processor class, or ``None`` when no
        plugin name could be determined.

    Raises:
        ValueError: If a plugin is requested but no plugin could be loaded,
            or the requested plugin is not among the loadable ones.
    """
    # Input/Output processors are loaded as plugins under the
    # 'vllm.io_processor_plugins' group. Similar to platform plugins, each
    # plugin registers a function that returns the qualified class name of
    # the processor to install.
    requested_plugin = plugin_from_init
    if not requested_plugin:
        # No explicit name: fall back to the custom field in the model's
        # hf_config, which may be absent.
        requested_plugin = vllm_config.model_config.hf_config.to_dict().get(
            "io_processor_plugin")

    if requested_plugin is None:
        logger.info("No IOProcessor plugins requested by the model")
        return None

    logger.debug("IOProcessor plugin to be loaded %s", requested_plugin)

    # Query every installed plugin in the group for its processor class.
    installed_plugins = load_plugins_by_group('vllm.io_processor_plugins')

    available = {}
    for plugin_name, entry_point in installed_plugins.items():
        try:
            assert callable(entry_point)
            qualname = entry_point()
        except Exception:
            # A broken plugin must not prevent the others from loading.
            logger.warning("Failed to load plugin %s.",
                           plugin_name,
                           exc_info=True)
            continue
        if qualname is not None:
            available[plugin_name] = qualname

    if not available:
        raise ValueError("No IOProcessor plugins installed"
                         f" but one is required ({requested_plugin}).")

    if requested_plugin not in available:
        raise ValueError(
            f"The model requires the '{requested_plugin}' IO Processor plugin "
            "but it is not installed. "
            f"Available plugins: {list(available)}")

    processor_cls = resolve_obj_by_qualname(available[requested_plugin])
    return processor_cls(vllm_config)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, Optional, TypeVar, Union
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
IOProcessorInput = TypeVar('IOProcessorInput')
IOProcessorOutput = TypeVar('IOProcessorOutput')


class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Abstract interface implemented by IO processor plugins.

    A processor parses a request into its own input type, turns that input
    into one or more prompts, converts the pooling outputs into its own
    output type, and finally wraps that output into a response object.

    Subclasses must implement ``pre_process``, ``post_process``,
    ``parse_request`` and ``output_to_response``. The ``*_async`` variants
    default to calling their synchronous counterparts and may be overridden.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Store the engine configuration for use by subclasses."""
        self.vllm_config = vllm_config

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        """Convert a parsed plugin input into one or more prompts.

        Args:
            prompt: Plugin-specific input, as returned by ``parse_request``.
            request_id: Optional identifier of the request being served.
            **kwargs: Extra plugin-specific options.

        Returns:
            A single prompt or a sequence of prompts.
        """
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        """Async variant of ``pre_process``.

        The default implementation simply calls ``pre_process``.
        """
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(self,
                     model_output: Sequence[PoolingRequestOutput],
                     request_id: Optional[str] = None,
                     **kwargs) -> IOProcessorOutput:
        """Convert pooling outputs into the plugin's output type.

        Args:
            model_output: Pooling outputs to convert.
            request_id: Optional identifier of the request being served.
            **kwargs: Extra plugin-specific options.

        Returns:
            The plugin-specific output.
        """
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: Optional[str] = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Async variant of ``post_process``.

        Drains ``model_output``, orders the outputs by the integer index
        paired with each of them, and passes the ordered outputs to
        ``post_process``.

        Args:
            model_output: Async stream of ``(index, output)`` pairs.
            request_id: Optional identifier of the request being served.
            **kwargs: Extra plugin-specific options.

        Returns:
            Whatever ``post_process`` returns for the ordered outputs.
        """
        indexed_outputs = [pair async for pair in model_output]
        # Nothing guarantees that pairs arrive in ascending index order, so
        # restore that order before dropping the index. The sort is stable,
        # so pairs sharing an index keep their arrival order.
        indexed_outputs.sort(key=lambda pair: pair[0])
        collected_output = [item for _, item in indexed_outputs]
        return self.post_process(collected_output, request_id, **kwargs)

    @abstractmethod
    def parse_request(self, request: Any) -> IOProcessorInput:
        """Validate ``request`` and convert it into the plugin's input type."""
        raise NotImplementedError

    @abstractmethod
    def output_to_response(
            self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
        """Wrap the plugin's output into an ``IOProcessorResponse``."""
        raise NotImplementedError
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment