Unverified Commit c9a1923b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Plugin] Simplify IO Processor Plugin interface (#34236)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent b482f71e
......@@ -14,8 +14,26 @@ IOProcessorOutput = TypeVar("IOProcessorOutput")
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig):
super().__init__()
self.vllm_config = vllm_config
@abstractmethod
def parse_data(self, data: object) -> IOProcessorInput:
raise NotImplementedError
def merge_sampling_params(
self,
params: SamplingParams | None = None,
) -> SamplingParams:
return params or SamplingParams()
def merge_pooling_params(
self,
params: PoolingParams | None = None,
) -> PoolingParams:
return params or PoolingParams()
@abstractmethod
def pre_process(
self,
......@@ -55,29 +73,13 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
[(i, item) async for i, item in model_output], key=lambda output: output[0]
)
collected_output = [output[1] for output in sorted_output]
return self.post_process(collected_output, request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
raise NotImplementedError
return self.post_process(collected_output, request_id=request_id, **kwargs)
```
The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
The `parse_data` method is used for validating the user data and converting it into the input expected by the `pre_process*` methods.
The `merge_sampling_params` and `merge_pooling_params` methods merge input `SamplingParams` or `PoolingParams` (if any) with the default one.
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py).
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_online.py](../../examples/pooling/plugin/prithvi_geospatial_mae_online.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples.
......
......@@ -18,18 +18,10 @@ from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
IOProcessorResponse,
)
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (
IOProcessor,
IOProcessorInput,
IOProcessorOutput,
)
from vllm.plugins.io_processors.interface import IOProcessor
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
......@@ -227,7 +219,7 @@ def load_image(
return imgs, temporal_coords, location_coords, metas
class PrithviMultimodalDataProcessor(IOProcessor):
class PrithviMultimodalDataProcessor(IOProcessor[ImagePrompt, ImageRequestOutput]):
indices = [0, 1, 2, 3, 4, 5]
def __init__(self, vllm_config: VllmConfig):
......@@ -251,34 +243,15 @@ class PrithviMultimodalDataProcessor(IOProcessor):
self.requests_cache: dict[str, dict[str, Any]] = {}
self.indices = DEFAULT_INPUT_INDICES
def parse_request(self, request: Any) -> IOProcessorInput:
if type(request) is dict:
image_prompt = ImagePrompt(**request)
return image_prompt
if isinstance(request, IOProcessorRequest):
if not hasattr(request, "data"):
raise ValueError("missing 'data' field in OpenAIBaseModel Request")
def parse_data(self, data: object) -> ImagePrompt:
if isinstance(data, dict):
return ImagePrompt(**data)
request_data = request.data
if type(request_data) is dict:
return ImagePrompt(**request_data)
else:
raise ValueError("Unable to parse the request data")
raise ValueError("Unable to parse request")
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
return IOProcessorResponse(
request_id=plugin_output.request_id,
data=plugin_output,
)
raise ValueError("Prompt data should be an `ImagePrompt`")
def pre_process(
self,
prompt: IOProcessorInput,
prompt: ImagePrompt,
request_id: str | None = None,
**kwargs,
) -> PromptType | Sequence[PromptType]:
......@@ -364,7 +337,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
model_output: Sequence[PoolingRequestOutput],
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
) -> ImageRequestOutput:
pred_imgs_list = []
if request_id and (request_id in self.requests_cache):
......@@ -409,5 +382,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
)
return ImageRequestOutput(
type=out_format, format="tiff", data=out_data, request_id=request_id
type=out_format,
format="tiff",
data=out_data,
)
......@@ -38,9 +38,6 @@ class ImagePrompt(BaseModel):
"""
MultiModalPromptType = ImagePrompt
class ImageRequestOutput(BaseModel):
"""
The output data of an image request to vLLM.
......@@ -54,4 +51,3 @@ class ImageRequestOutput(BaseModel):
type: Literal["path", "b64_json"]
format: str
data: str
request_id: str | None = None
......@@ -75,9 +75,7 @@ async def test_prithvi_mae_plugin_online(
# verify the output is formatted as expected for this plugin
plugin_data = parsed_response.data
assert all(
plugin_data.get(attr) for attr in ["type", "format", "data", "request_id"]
)
assert all(plugin_data.get(attr) for attr in ["type", "format", "data"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
......@@ -110,9 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr) for attr in ["type", "format", "data", "request_id"]
)
assert all(hasattr(output, attr) for attr in ["type", "format", "data"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
......
......@@ -85,7 +85,6 @@ from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
......@@ -95,6 +94,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
_R = TypeVar("_R", default=Any)
......@@ -1056,9 +1056,7 @@ class LLM:
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True
if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
......@@ -1068,40 +1066,42 @@ class LLM:
)
# Validate the request data is valid for the loaded plugin
validated_prompt = self.io_processor.parse_request(prompts)
validated_prompt = self.io_processor.parse_data(prompts)
# obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt)
prompts_seq = prompt_to_seq(prompts)
if io_processor_prompt:
assert self.io_processor is not None
if is_list_of(pooling_params, PoolingParams):
validated_pooling_params: list[PoolingParams] = []
for param in as_iter(pooling_params):
validated_pooling_params.append(
self.io_processor.validate_or_generate_params(param)
params_seq: Sequence[PoolingParams] = [
self.io_processor.merge_pooling_params(param)
for param in self._params_to_seq(
pooling_params,
len(prompts_seq),
)
pooling_params = validated_pooling_params
]
for p in params_seq:
if p.task is None:
p.task = "plugin"
else:
assert not isinstance(pooling_params, Sequence)
pooling_params = self.io_processor.validate_or_generate_params(
pooling_params
)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
for param in as_iter(pooling_params):
prompts_seq = prompt_to_seq(prompts)
params_seq = self._params_to_seq(pooling_params, len(prompts_seq))
for param in params_seq:
if param.task is None:
param.task = pooling_task
elif param.task != pooling_task:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
msg = (
f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
)
raise ValueError(msg)
outputs = self._run_completion(
prompts=prompts,
params=pooling_params,
prompts=prompts_seq,
params=params_seq,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
......@@ -1111,12 +1111,10 @@ class LLM:
outputs, PoolingRequestOutput
)
if io_processor_prompt:
if use_io_processor:
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(
model_output=model_outputs
)
processed_outputs = self.io_processor.post_process(model_outputs)
return [
PoolingRequestOutput[Any](
......@@ -1662,11 +1660,9 @@ class LLM:
def _params_to_seq(
self,
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
params: _P | Sequence[_P],
num_requests: int,
) -> Sequence[SamplingParams | PoolingParams]:
) -> Sequence[_P]:
if isinstance(params, Sequence):
if len(params) != num_requests:
raise ValueError(
......
......@@ -100,9 +100,6 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
data: T
task: PoolingTask = "plugin"
def to_pooling_params(self):
return PoolingParams(task=self.task)
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
request_id: str | None = None
......
......@@ -85,7 +85,6 @@ class OpenAIServingPooling(OpenAIServing):
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
is_io_processor_request = isinstance(request, IOProcessorRequest)
try:
lora_request = self._maybe_get_adapters(request)
......@@ -95,7 +94,7 @@ class OpenAIServingPooling(OpenAIServing):
)
engine_prompts: Sequence[PromptType | TokPrompt]
if is_io_processor_request:
if use_io_processor := isinstance(request, IOProcessorRequest):
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
......@@ -104,7 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
"offline inference example for more details."
)
validated_prompt = self.io_processor.parse_request(request)
validated_prompt = self.io_processor.parse_data(request.data)
raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
......@@ -141,13 +140,18 @@ class OpenAIServingPooling(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
if is_io_processor_request:
assert self.io_processor is not None and isinstance(
request, IOProcessorRequest
)
pooling_params = self.io_processor.validate_or_generate_params()
if use_io_processor:
assert self.io_processor is not None
pooling_params = self.io_processor.merge_pooling_params()
if pooling_params.task is None:
pooling_params.task = "plugin"
tokenization_kwargs: dict[str, Any] = {}
else:
pooling_params = request.to_pooling_params()
pooling_params = request.to_pooling_params() # type: ignore
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
......@@ -165,12 +169,6 @@ class OpenAIServingPooling(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
if is_io_processor_request:
tokenization_kwargs: dict[str, Any] = {}
else:
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
......@@ -187,13 +185,31 @@ class OpenAIServingPooling(OpenAIServing):
result_generator = merge_async_iterators(*generators)
if is_io_processor_request:
if use_io_processor:
assert self.io_processor is not None
output = await self.io_processor.post_process_async(
model_output=result_generator,
result_generator,
request_id=request_id,
)
return self.io_processor.output_to_response(output)
if callable(
output_to_response := getattr(
self.io_processor, "output_to_response", None
)
):
logger.warning_once(
"`IOProcessor.output_to_response` is deprecated. To ensure "
"consistency between offline and online APIs, "
"`IOProcessorResponse` will become a transparent wrapper "
"around output data from v0.19 onwards.",
)
if hasattr(output, "request_id") and output.request_id is None:
output.request_id = request_id # type: ignore
return output_to_response(output) # type: ignore
return IOProcessorResponse(request_id=request_id, data=output)
assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_prompts)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, TypeVar
from typing import Generic, TypeVar
from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams
......@@ -18,8 +17,68 @@ IOProcessorOutput = TypeVar("IOProcessorOutput")
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig):
super().__init__()
self.vllm_config = vllm_config
def parse_data(self, data: object) -> IOProcessorInput:
if callable(parse_request := getattr(self, "parse_request", None)):
warnings.warn(
"`parse_request` has been renamed to `parse_data`. "
"Please update your IO Processor Plugin to use the new name. "
"The old name will be removed in v0.19.",
DeprecationWarning,
stacklevel=2,
)
return parse_request(data) # type: ignore
raise NotImplementedError
def merge_sampling_params(
self,
params: SamplingParams | None = None,
) -> SamplingParams:
if callable(
validate_or_generate_params := getattr(
self, "validate_or_generate_params", None
)
):
warnings.warn(
"`validate_or_generate_params` has been split into "
"`merge_sampling_params` and `merge_pooling_params`."
"Please update your IO Processor Plugin to use the new methods. "
"The old name will be removed in v0.19.",
DeprecationWarning,
stacklevel=2,
)
return validate_or_generate_params(params) # type: ignore
return params or SamplingParams()
def merge_pooling_params(
self,
params: PoolingParams | None = None,
) -> PoolingParams:
if callable(
validate_or_generate_params := getattr(
self, "validate_or_generate_params", None
)
):
warnings.warn(
"`validate_or_generate_params` has been split into "
"`merge_sampling_params` and `merge_pooling_params`."
"Please update your IO Processor Plugin to use the new methods. "
"The old name will be removed in v0.19.",
DeprecationWarning,
stacklevel=2,
)
return validate_or_generate_params(params) # type: ignore
return params or PoolingParams(task="plugin")
@abstractmethod
def pre_process(
self,
......@@ -59,19 +118,4 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
[(i, item) async for i, item in model_output], key=lambda output: output[0]
)
collected_output = [output[1] for output in sorted_output]
return self.post_process(collected_output, request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
raise NotImplementedError
return self.post_process(collected_output, request_id=request_id, **kwargs)
......@@ -51,12 +51,6 @@ def as_list(maybe_list: Iterable[T]) -> list[T]:
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
if isinstance(obj, str) or not isinstance(obj, Iterable):
return [obj] # type: ignore[list-item]
return obj
def is_list_of(
value: object,
typ: type[T] | tuple[type[T], ...],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment