interface.py 2.24 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
6
from typing import Any, Generic, TypeVar
7
8
9
10
11
12

from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput

13
14
# Type variable for the plugin-specific input type: returned by
# IOProcessor.parse_request and accepted by IOProcessor.pre_process.
IOProcessorInput = TypeVar("IOProcessorInput")
# Type variable for the plugin-specific output type: returned by
# IOProcessor.post_process and accepted by IOProcessor.output_to_response.
IOProcessorOutput = TypeVar("IOProcessorOutput")
15
16
17
18
19
20
21
22
23
24


class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Abstract interface for IO processor implementations.

    The class is generic over two types:

    * ``IOProcessorInput`` -- the value returned by ``parse_request`` and
      accepted by ``pre_process``.
    * ``IOProcessorOutput`` -- the value returned by ``post_process`` and
      accepted by ``output_to_response``.

    Subclasses must implement ``pre_process``, ``post_process``,
    ``parse_request`` and ``output_to_response``. The ``*_async`` methods
    have default implementations that delegate to their synchronous
    counterparts.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Keep a reference to the vLLM configuration on the instance.

        Args:
            vllm_config: Configuration object, stored as
                ``self.vllm_config``.
        """
        self.vllm_config = vllm_config

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Turn a processor-specific input into one or more prompts.

        Args:
            prompt: Input of the processor's own input type.
            request_id: Optional identifier of the request being handled.
            **kwargs: Extra implementation-defined options.

        Returns:
            A single prompt or a sequence of prompts.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Asynchronous entry point for ``pre_process``.

        The default implementation simply calls ``pre_process`` with the
        same arguments and returns its result.
        """
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Turn a sequence of pooling outputs into the processor's output.

        Args:
            model_output: Pooling outputs to convert.
            request_id: Optional identifier of the request being handled.
            **kwargs: Extra implementation-defined options.

        Returns:
            A value of the processor's own output type.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Collect an async stream of outputs and post-process them.

        The generator is fully drained before ``post_process`` is called.

        Args:
            model_output: Async generator yielding ``(index, output)``
                pairs.
            request_id: Optional identifier of the request being handled.
            **kwargs: Extra options forwarded to ``post_process``.

        Returns:
            Whatever ``post_process`` returns for the index-ordered outputs.
        """
        # We cannot guarantee outputs are returned in the same order they
        # were fed to vLLM, so sort them by index before post-processing.
        indexed_outputs = [
            (index, output) async for index, output in model_output
        ]
        # list.sort is stable, so pairs with equal indices keep their
        # arrival order.
        indexed_outputs.sort(key=lambda pair: pair[0])
        ordered_outputs = [output for _, output in indexed_outputs]
        return self.post_process(ordered_outputs, request_id, **kwargs)

    @abstractmethod
    def parse_request(self, request: Any) -> IOProcessorInput:
        """Convert an incoming request into the processor's input type.

        Args:
            request: The request object; its type is not constrained here.

        Returns:
            A value of the processor's own input type.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def output_to_response(
        self, plugin_output: IOProcessorOutput
    ) -> IOProcessorResponse:
        """Wrap the processor's output in an ``IOProcessorResponse``.

        Args:
            plugin_output: A value of the processor's own output type.

        Returns:
            The response object built from ``plugin_output``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError