"vscode:/vscode.git/clone" did not exist on "f12b20deccbc6c8bb5cdeac053d75178341c66c1"
interface.py 2.51 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
6
from typing import Any, Generic, TypeVar
7
8
9
10
11

from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
12
13
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
14

15
16
# Type variables that parameterize IOProcessor: IOProcessorInput is the type
# accepted by pre_process (and produced by parse_request); IOProcessorOutput
# is the type returned by post_process (and consumed by output_to_response).
IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
17
18
19
20
21
22
23
24
25
26


class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Abstract interface for translating custom inputs/outputs for vLLM.

    The class is generic over two types: ``IOProcessorInput``, which is
    turned into one or more vLLM prompts by ``pre_process``, and
    ``IOProcessorOutput``, which ``post_process`` builds from the pooling
    outputs. Subclasses must implement ``pre_process``, ``post_process``,
    ``parse_request`` and ``output_to_response``; the ``*_async`` variants
    and ``validate_or_generate_params`` ship with default implementations.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Store the vLLM configuration on the instance."""
        self.vllm_config = vllm_config

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Convert ``prompt`` into one prompt or a sequence of prompts.

        Args:
            prompt: The processor-specific input.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra, implementation-defined options.

        Returns:
            A single ``PromptType`` or a sequence of them.
        """
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Asynchronous counterpart of ``pre_process``.

        The default implementation simply calls ``pre_process`` with the same
        arguments and returns its result.
        """
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Convert the pooling outputs into the processor's output type.

        Args:
            model_output: The pooling outputs to convert.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra, implementation-defined options.

        Returns:
            The processor-specific output.
        """
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Asynchronous counterpart of ``post_process``.

        Drains ``model_output``, which yields ``(index, output)`` pairs,
        orders the outputs by index and hands them to ``post_process``.

        Args:
            model_output: Async generator of ``(index, output)`` pairs.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra options forwarded to ``post_process``.

        Returns:
            Whatever ``post_process`` returns for the ordered outputs.
        """
        # We cannot guarantee outputs are returned in the same order they
        # were fed to vLLM, so collect everything first and sort by index
        # before post-processing.
        indexed_outputs = [(index, output) async for index, output in model_output]
        indexed_outputs.sort(key=lambda pair: pair[0])
        ordered_outputs = [output for _, output in indexed_outputs]
        return self.post_process(ordered_outputs, request_id, **kwargs)

    @abstractmethod
    def parse_request(self, request: Any) -> IOProcessorInput:
        """Build the processor-specific input from ``request``."""
        raise NotImplementedError

    def validate_or_generate_params(
        self, params: SamplingParams | PoolingParams | None = None
    ) -> SamplingParams | PoolingParams:
        """Return ``params`` when given, otherwise a default ``PoolingParams``.

        Subclasses may override this to validate or adjust the parameters.
        """
        # Test against None explicitly so that a caller-supplied object is
        # never discarded merely because it evaluates as falsy.
        if params is not None:
            return params
        return PoolingParams()

    @abstractmethod
    def output_to_response(
        self, plugin_output: IOProcessorOutput
    ) -> IOProcessorResponse:
        """Wrap ``plugin_output`` in an ``IOProcessorResponse``."""
        raise NotImplementedError