Unverified Commit c9a1923b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Plugin] Simplify IO Processor Plugin interface (#34236)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent b482f71e
...@@ -14,8 +14,26 @@ IOProcessorOutput = TypeVar("IOProcessorOutput") ...@@ -14,8 +14,26 @@ IOProcessorOutput = TypeVar("IOProcessorOutput")
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
super().__init__()
self.vllm_config = vllm_config self.vllm_config = vllm_config
@abstractmethod
def parse_data(self, data: object) -> IOProcessorInput:
raise NotImplementedError
def merge_sampling_params(
self,
params: SamplingParams | None = None,
) -> SamplingParams:
return params or SamplingParams()
def merge_pooling_params(
self,
params: PoolingParams | None = None,
) -> PoolingParams:
return params or PoolingParams()
@abstractmethod @abstractmethod
def pre_process( def pre_process(
self, self,
...@@ -55,29 +73,13 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): ...@@ -55,29 +73,13 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
[(i, item) async for i, item in model_output], key=lambda output: output[0] [(i, item) async for i, item in model_output], key=lambda output: output[0]
) )
collected_output = [output[1] for output in sorted_output] collected_output = [output[1] for output in sorted_output]
return self.post_process(collected_output, request_id, **kwargs) return self.post_process(collected_output, request_id=request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
raise NotImplementedError
``` ```
The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods. The `parse_data` method is used for validating the user data and converting it into the input expected by the `pre_process*` methods.
The `merge_sampling_params` and `merge_pooling_params` methods merge the input `SamplingParams` or `PoolingParams` (if any) with the default ones.
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParams`/`PoolingParams` received with the user request, or for generating new ones if none are specified. The function always returns the validated/generated parameters.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here: [vllm/entrypoints/pooling/pooling/serving.py](../../vllm/entrypoints/pooling/pooling/serving.py).
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_online.py](../../examples/pooling/plugin/prithvi_geospatial_mae_online.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples. An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_online.py](../../examples/pooling/plugin/prithvi_geospatial_mae_online.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples.
......
...@@ -18,18 +18,10 @@ from einops import rearrange ...@@ -18,18 +18,10 @@ from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
IOProcessorResponse,
)
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import ( from vllm.plugins.io_processors.interface import IOProcessor
IOProcessor,
IOProcessorInput,
IOProcessorOutput,
)
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
...@@ -227,7 +219,7 @@ def load_image( ...@@ -227,7 +219,7 @@ def load_image(
return imgs, temporal_coords, location_coords, metas return imgs, temporal_coords, location_coords, metas
class PrithviMultimodalDataProcessor(IOProcessor): class PrithviMultimodalDataProcessor(IOProcessor[ImagePrompt, ImageRequestOutput]):
indices = [0, 1, 2, 3, 4, 5] indices = [0, 1, 2, 3, 4, 5]
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
...@@ -251,34 +243,15 @@ class PrithviMultimodalDataProcessor(IOProcessor): ...@@ -251,34 +243,15 @@ class PrithviMultimodalDataProcessor(IOProcessor):
self.requests_cache: dict[str, dict[str, Any]] = {} self.requests_cache: dict[str, dict[str, Any]] = {}
self.indices = DEFAULT_INPUT_INDICES self.indices = DEFAULT_INPUT_INDICES
def parse_request(self, request: Any) -> IOProcessorInput: def parse_data(self, data: object) -> ImagePrompt:
if type(request) is dict: if isinstance(data, dict):
image_prompt = ImagePrompt(**request) return ImagePrompt(**data)
return image_prompt
if isinstance(request, IOProcessorRequest): raise ValueError("Prompt data should be an `ImagePrompt`")
if not hasattr(request, "data"):
raise ValueError("missing 'data' field in OpenAIBaseModel Request")
request_data = request.data
if type(request_data) is dict:
return ImagePrompt(**request_data)
else:
raise ValueError("Unable to parse the request data")
raise ValueError("Unable to parse request")
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
return IOProcessorResponse(
request_id=plugin_output.request_id,
data=plugin_output,
)
def pre_process( def pre_process(
self, self,
prompt: IOProcessorInput, prompt: ImagePrompt,
request_id: str | None = None, request_id: str | None = None,
**kwargs, **kwargs,
) -> PromptType | Sequence[PromptType]: ) -> PromptType | Sequence[PromptType]:
...@@ -364,7 +337,7 @@ class PrithviMultimodalDataProcessor(IOProcessor): ...@@ -364,7 +337,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
model_output: Sequence[PoolingRequestOutput], model_output: Sequence[PoolingRequestOutput],
request_id: str | None = None, request_id: str | None = None,
**kwargs, **kwargs,
) -> IOProcessorOutput: ) -> ImageRequestOutput:
pred_imgs_list = [] pred_imgs_list = []
if request_id and (request_id in self.requests_cache): if request_id and (request_id in self.requests_cache):
...@@ -409,5 +382,7 @@ class PrithviMultimodalDataProcessor(IOProcessor): ...@@ -409,5 +382,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
) )
return ImageRequestOutput( return ImageRequestOutput(
type=out_format, format="tiff", data=out_data, request_id=request_id type=out_format,
format="tiff",
data=out_data,
) )
...@@ -38,9 +38,6 @@ class ImagePrompt(BaseModel): ...@@ -38,9 +38,6 @@ class ImagePrompt(BaseModel):
""" """
MultiModalPromptType = ImagePrompt
class ImageRequestOutput(BaseModel): class ImageRequestOutput(BaseModel):
""" """
The output data of an image request to vLLM. The output data of an image request to vLLM.
...@@ -54,4 +51,3 @@ class ImageRequestOutput(BaseModel): ...@@ -54,4 +51,3 @@ class ImageRequestOutput(BaseModel):
type: Literal["path", "b64_json"] type: Literal["path", "b64_json"]
format: str format: str
data: str data: str
request_id: str | None = None
...@@ -75,9 +75,7 @@ async def test_prithvi_mae_plugin_online( ...@@ -75,9 +75,7 @@ async def test_prithvi_mae_plugin_online(
# verify the output is formatted as expected for this plugin # verify the output is formatted as expected for this plugin
plugin_data = parsed_response.data plugin_data = parsed_response.data
assert all( assert all(plugin_data.get(attr) for attr in ["type", "format", "data"])
plugin_data.get(attr) for attr in ["type", "format", "data", "request_id"]
)
# We just check that the output is a valid base64 string. # We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted. # Raises an exception and fails the test if the string is corrupted.
...@@ -110,9 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): ...@@ -110,9 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
output = pooler_output[0].outputs output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin # verify the output is formatted as expected for this plugin
assert all( assert all(hasattr(output, attr) for attr in ["type", "format", "data"])
hasattr(output, attr) for attr in ["type", "format", "data", "request_id"]
)
# We just check that the output is a valid base64 string. # We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted. # Raises an exception and fails the test if the string is corrupted.
......
...@@ -85,7 +85,6 @@ from vllm.tasks import PoolingTask ...@@ -85,7 +85,6 @@ from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter from vllm.utils.counter import Counter
from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
...@@ -95,6 +94,7 @@ if TYPE_CHECKING: ...@@ -95,6 +94,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
_R = TypeVar("_R", default=Any) _R = TypeVar("_R", default=Any)
...@@ -1056,9 +1056,7 @@ class LLM: ...@@ -1056,9 +1056,7 @@ class LLM:
dict(truncate_prompt_tokens=truncate_prompt_tokens), dict(truncate_prompt_tokens=truncate_prompt_tokens),
) )
io_processor_prompt = False if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
"No IOProcessor plugin installed. Please refer " "No IOProcessor plugin installed. Please refer "
...@@ -1068,40 +1066,42 @@ class LLM: ...@@ -1068,40 +1066,42 @@ class LLM:
) )
# Validate the request data is valid for the loaded plugin # Validate the request data is valid for the loaded plugin
validated_prompt = self.io_processor.parse_request(prompts) validated_prompt = self.io_processor.parse_data(prompts)
# obtain the actual model prompts from the pre-processor # obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt) prompts = self.io_processor.pre_process(prompt=validated_prompt)
prompts_seq = prompt_to_seq(prompts)
if io_processor_prompt: params_seq: Sequence[PoolingParams] = [
assert self.io_processor is not None self.io_processor.merge_pooling_params(param)
if is_list_of(pooling_params, PoolingParams): for param in self._params_to_seq(
validated_pooling_params: list[PoolingParams] = [] pooling_params,
for param in as_iter(pooling_params): len(prompts_seq),
validated_pooling_params.append(
self.io_processor.validate_or_generate_params(param)
)
pooling_params = validated_pooling_params
else:
assert not isinstance(pooling_params, Sequence)
pooling_params = self.io_processor.validate_or_generate_params(
pooling_params
) )
]
if pooling_params is None: for p in params_seq:
# Use default pooling params. if p.task is None:
pooling_params = PoolingParams() p.task = "plugin"
else:
for param in as_iter(pooling_params): if pooling_params is None:
if param.task is None: # Use default pooling params.
param.task = pooling_task pooling_params = PoolingParams()
elif param.task != pooling_task:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!" prompts_seq = prompt_to_seq(prompts)
raise ValueError(msg) params_seq = self._params_to_seq(pooling_params, len(prompts_seq))
for param in params_seq:
if param.task is None:
param.task = pooling_task
elif param.task != pooling_task:
msg = (
f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
)
raise ValueError(msg)
outputs = self._run_completion( outputs = self._run_completion(
prompts=prompts, prompts=prompts_seq,
params=pooling_params, params=params_seq,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
...@@ -1111,12 +1111,10 @@ class LLM: ...@@ -1111,12 +1111,10 @@ class LLM:
outputs, PoolingRequestOutput outputs, PoolingRequestOutput
) )
if io_processor_prompt: if use_io_processor:
# get the post-processed model outputs # get the post-processed model outputs
assert self.io_processor is not None assert self.io_processor is not None
processed_outputs = self.io_processor.post_process( processed_outputs = self.io_processor.post_process(model_outputs)
model_output=model_outputs
)
return [ return [
PoolingRequestOutput[Any]( PoolingRequestOutput[Any](
...@@ -1662,11 +1660,9 @@ class LLM: ...@@ -1662,11 +1660,9 @@ class LLM:
def _params_to_seq( def _params_to_seq(
self, self,
params: SamplingParams params: _P | Sequence[_P],
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
num_requests: int, num_requests: int,
) -> Sequence[SamplingParams | PoolingParams]: ) -> Sequence[_P]:
if isinstance(params, Sequence): if isinstance(params, Sequence):
if len(params) != num_requests: if len(params) != num_requests:
raise ValueError( raise ValueError(
......
...@@ -100,9 +100,6 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic ...@@ -100,9 +100,6 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
data: T data: T
task: PoolingTask = "plugin" task: PoolingTask = "plugin"
def to_pooling_params(self):
return PoolingParams(task=self.task)
class IOProcessorResponse(OpenAIBaseModel, Generic[T]): class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
request_id: str | None = None request_id: str | None = None
......
...@@ -85,7 +85,6 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -85,7 +85,6 @@ class OpenAIServingPooling(OpenAIServing):
request_id = f"pool-{self._base_request_id(raw_request)}" request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time()) created_time = int(time.time())
is_io_processor_request = isinstance(request, IOProcessorRequest)
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
...@@ -95,7 +94,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -95,7 +94,7 @@ class OpenAIServingPooling(OpenAIServing):
) )
engine_prompts: Sequence[PromptType | TokPrompt] engine_prompts: Sequence[PromptType | TokPrompt]
if is_io_processor_request: if use_io_processor := isinstance(request, IOProcessorRequest):
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
"No IOProcessor plugin installed. Please refer " "No IOProcessor plugin installed. Please refer "
...@@ -104,7 +103,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -104,7 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
"offline inference example for more details." "offline inference example for more details."
) )
validated_prompt = self.io_processor.parse_request(request) validated_prompt = self.io_processor.parse_data(request.data)
raw_prompts = await self.io_processor.pre_process_async( raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id prompt=validated_prompt, request_id=request_id
...@@ -141,13 +140,18 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -141,13 +140,18 @@ class OpenAIServingPooling(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try: try:
if is_io_processor_request: if use_io_processor:
assert self.io_processor is not None and isinstance( assert self.io_processor is not None
request, IOProcessorRequest
) pooling_params = self.io_processor.merge_pooling_params()
pooling_params = self.io_processor.validate_or_generate_params() if pooling_params.task is None:
pooling_params.task = "plugin"
tokenization_kwargs: dict[str, Any] = {}
else: else:
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params() # type: ignore
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
...@@ -165,12 +169,6 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -165,12 +169,6 @@ class OpenAIServingPooling(OpenAIServing):
else await self._get_trace_headers(raw_request.headers) else await self._get_trace_headers(raw_request.headers)
) )
if is_io_processor_request:
tokenization_kwargs: dict[str, Any] = {}
else:
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_prompt,
pooling_params, pooling_params,
...@@ -187,13 +185,31 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -187,13 +185,31 @@ class OpenAIServingPooling(OpenAIServing):
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
if is_io_processor_request: if use_io_processor:
assert self.io_processor is not None assert self.io_processor is not None
output = await self.io_processor.post_process_async( output = await self.io_processor.post_process_async(
model_output=result_generator, result_generator,
request_id=request_id, request_id=request_id,
) )
return self.io_processor.output_to_response(output)
if callable(
output_to_response := getattr(
self.io_processor, "output_to_response", None
)
):
logger.warning_once(
"`IOProcessor.output_to_response` is deprecated. To ensure "
"consistency between offline and online APIs, "
"`IOProcessorResponse` will become a transparent wrapper "
"around output data from v0.19 onwards.",
)
if hasattr(output, "request_id") and output.request_id is None:
output.request_id = request_id # type: ignore
return output_to_response(output) # type: ignore
return IOProcessorResponse(request_id=request_id, data=output)
assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest)) assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_prompts) num_prompts = len(engine_prompts)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, TypeVar from typing import Generic, TypeVar
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -18,8 +17,68 @@ IOProcessorOutput = TypeVar("IOProcessorOutput") ...@@ -18,8 +17,68 @@ IOProcessorOutput = TypeVar("IOProcessorOutput")
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
super().__init__()
self.vllm_config = vllm_config self.vllm_config = vllm_config
def parse_data(self, data: object) -> IOProcessorInput:
    """Validate raw user data and convert it into the plugin's input.

    Plugins are expected to override this method. For backward
    compatibility, an instance that still exposes the legacy
    `parse_request` callable is dispatched to it, and a
    `DeprecationWarning` is emitted.

    Args:
        data: The raw data supplied with the request.

    Returns:
        The value returned by the legacy `parse_request` hook.

    Raises:
        NotImplementedError: If the instance provides neither an override
            of this method nor a callable `parse_request`.
    """
    legacy_parse_request = getattr(self, "parse_request", None)
    if not callable(legacy_parse_request):
        # Name the concrete class so plugin authors can see at a glance
        # which processor is missing the hook.
        raise NotImplementedError(
            f"{type(self).__name__} must implement `parse_data`"
        )

    warnings.warn(
        "`parse_request` has been renamed to `parse_data`. "
        "Please update your IO Processor Plugin to use the new name. "
        "The old name will be removed in v0.19.",
        DeprecationWarning,
        # Attribute the warning to the caller of `parse_data`.
        stacklevel=2,
    )
    return legacy_parse_request(data)  # type: ignore
def merge_sampling_params(
    self,
    params: SamplingParams | None = None,
) -> SamplingParams:
    """Merge caller-supplied sampling parameters with the defaults.

    For backward compatibility, an instance that still exposes the legacy
    `validate_or_generate_params` callable is dispatched to it, and a
    `DeprecationWarning` is emitted.

    Args:
        params: Sampling parameters received with the request, if any.

    Returns:
        The result of the legacy hook when present; otherwise `params`
        when it is truthy, else a freshly constructed `SamplingParams()`.
    """
    legacy_hook = getattr(self, "validate_or_generate_params", None)
    if callable(legacy_hook):
        warnings.warn(
            "`validate_or_generate_params` has been split into "
            # The trailing space keeps the two sentences from running
            # together once the adjacent literals are concatenated.
            "`merge_sampling_params` and `merge_pooling_params`. "
            "Please update your IO Processor Plugin to use the new methods. "
            "The old name will be removed in v0.19.",
            DeprecationWarning,
            # Attribute the warning to the caller of this method.
            stacklevel=2,
        )
        return legacy_hook(params)  # type: ignore

    return params or SamplingParams()
def merge_pooling_params(
    self,
    params: PoolingParams | None = None,
) -> PoolingParams:
    """Merge caller-supplied pooling parameters with the defaults.

    For backward compatibility, an instance that still exposes the legacy
    `validate_or_generate_params` callable is dispatched to it, and a
    `DeprecationWarning` is emitted.

    Args:
        params: Pooling parameters received with the request, if any.

    Returns:
        The result of the legacy hook when present; otherwise `params`
        when it is truthy, else `PoolingParams(task="plugin")`.
    """
    legacy_hook = getattr(self, "validate_or_generate_params", None)
    if callable(legacy_hook):
        warnings.warn(
            "`validate_or_generate_params` has been split into "
            # The trailing space keeps the two sentences from running
            # together once the adjacent literals are concatenated.
            "`merge_sampling_params` and `merge_pooling_params`. "
            "Please update your IO Processor Plugin to use the new methods. "
            "The old name will be removed in v0.19.",
            DeprecationWarning,
            # Attribute the warning to the caller of this method.
            stacklevel=2,
        )
        return legacy_hook(params)  # type: ignore

    return params or PoolingParams(task="plugin")
@abstractmethod @abstractmethod
def pre_process( def pre_process(
self, self,
...@@ -59,19 +118,4 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): ...@@ -59,19 +118,4 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
[(i, item) async for i, item in model_output], key=lambda output: output[0] [(i, item) async for i, item in model_output], key=lambda output: output[0]
) )
collected_output = [output[1] for output in sorted_output] collected_output = [output[1] for output in sorted_output]
return self.post_process(collected_output, request_id, **kwargs) return self.post_process(collected_output, request_id=request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
raise NotImplementedError
...@@ -51,12 +51,6 @@ def as_list(maybe_list: Iterable[T]) -> list[T]: ...@@ -51,12 +51,6 @@ def as_list(maybe_list: Iterable[T]) -> list[T]:
return maybe_list if isinstance(maybe_list, list) else list(maybe_list) return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
# Normalize `obj` so that callers can always iterate over the result.
# A `str` or any non-`Iterable` value is wrapped in a one-element list;
# every other iterable is returned unchanged (no copy is made).
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
if isinstance(obj, str) or not isinstance(obj, Iterable):
return [obj] # type: ignore[list-item]
return obj
def is_list_of( def is_list_of(
value: object, value: object,
typ: type[T] | tuple[type[T], ...], typ: type[T] | tuple[type[T], ...],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment