interface.py 2.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, Optional, TypeVar, Union

from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput

# Type of the parsed request data: returned by IOProcessor.parse_request and
# accepted as the `prompt` argument of pre_process / pre_process_async.
IOProcessorInput = TypeVar('IOProcessorInput')
# Type of the processed result: returned by post_process / post_process_async
# and accepted by IOProcessor.output_to_response.
IOProcessorOutput = TypeVar('IOProcessorOutput')


class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Abstract interface for translating requests to prompts and back.

    A concrete subclass has to implement four hooks:

    * ``parse_request`` turns a raw request into ``IOProcessorInput``.
    * ``pre_process`` turns ``IOProcessorInput`` into one or more prompts.
    * ``post_process`` turns pooling outputs into ``IOProcessorOutput``.
    * ``output_to_response`` wraps ``IOProcessorOutput`` in an
      ``IOProcessorResponse``.

    The ``*_async`` variants have default implementations that delegate to
    their synchronous counterparts; subclasses may override them.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Keep a reference to the vLLM configuration on the instance.

        Args:
            vllm_config: Configuration object, stored as ``self.vllm_config``.
        """
        self.vllm_config = vllm_config

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        """Convert parsed input into a prompt or a sequence of prompts.

        Args:
            prompt: Input data in the processor-specific input type.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra, implementation-defined options.

        Returns:
            A single prompt or a sequence of prompts.
        """
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        """Coroutine variant of :meth:`pre_process`.

        The default implementation simply calls the synchronous
        ``pre_process`` with the same arguments and returns its result.
        """
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(self,
                     model_output: Sequence[PoolingRequestOutput],
                     request_id: Optional[str] = None,
                     **kwargs) -> IOProcessorOutput:
        """Convert pooling outputs into the processor-specific output type.

        Args:
            model_output: The pooling outputs to convert.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra, implementation-defined options.

        Returns:
            The processed output.
        """
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: Optional[str] = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Collect a stream of outputs and hand them to :meth:`post_process`.

        The generator is drained completely, the ``(int, output)`` pairs are
        ordered by their integer, and the outputs alone are then passed to
        the synchronous ``post_process``.

        NOTE(review): the integer is assumed to be the index of the prompt
        that produced the output; confirm against the caller. Ordering by it
        keeps ``post_process`` input aligned with the prompts even when the
        stream yields results in a different order.

        Args:
            model_output: Async generator of ``(int, PoolingRequestOutput)``.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra options forwarded to ``post_process``.

        Returns:
            Whatever ``post_process`` returns for the ordered outputs.
        """
        indexed_output = [(index, item) async for index, item in model_output]
        # Sort on the integer only: the key avoids comparing output objects
        # with each other, and the stable sort keeps arrival order for ties.
        indexed_output.sort(key=lambda pair: pair[0])
        collected_output = [item for _, item in indexed_output]
        return self.post_process(collected_output, request_id, **kwargs)

    @abstractmethod
    def parse_request(self, request: Any) -> IOProcessorInput:
        """Extract the processor-specific input from a raw request.

        Args:
            request: The incoming request object; its type is not constrained
                by this interface.

        Returns:
            The parsed input, suitable for ``pre_process``.
        """
        raise NotImplementedError

    @abstractmethod
    def output_to_response(
            self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
        """Wrap processed output in an ``IOProcessorResponse``.

        Args:
            plugin_output: A value of the processor-specific output type, as
                returned by ``post_process``.

        Returns:
            The response object built from ``plugin_output``.
        """
        raise NotImplementedError