interface.py 4.09 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import warnings
4
5
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
6
from typing import Generic, TypeVar
7
8
9
10

from vllm.config import VllmConfig
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
11
12
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
13

14
15
# Type variables that parameterize `IOProcessor`:
# - `IOProcessorInput` is the type returned by `parse_data` and accepted by
#   `pre_process` / `pre_process_async`.
# - `IOProcessorOutput` is the type returned by `post_process` /
#   `post_process_async`.
IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
16
17
18
19


class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Abstract base class for IO Processor plugins.

    The class is generic over the plugin's input type (`IOProcessorInput`)
    and output type (`IOProcessorOutput`). Subclasses must implement
    `pre_process` and `post_process`; the remaining methods have default
    implementations, some of which fall back to deprecated method names
    (`parse_request`, `validate_or_generate_params`) when a subclass still
    defines them.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Store the vLLM configuration on the instance.

        Args:
            vllm_config: Configuration object kept as `self.vllm_config`.
        """
        super().__init__()

        self.vllm_config = vllm_config

    def parse_data(self, data: object) -> IOProcessorInput:
        """Parse raw request data into the plugin's input type.

        The default implementation only supports the deprecated
        `parse_request` hook: if the instance exposes a callable attribute
        with that name, a `DeprecationWarning` is emitted and the call is
        delegated to it.

        Args:
            data: The raw object to parse.

        Returns:
            Whatever the legacy `parse_request` hook returns.

        Raises:
            NotImplementedError: If no callable `parse_request` exists,
                i.e. the subclass is expected to override this method.
        """
        if callable(parse_request := getattr(self, "parse_request", None)):
            warnings.warn(
                "`parse_request` has been renamed to `parse_data`. "
                "Please update your IO Processor Plugin to use the new name. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                # Attribute the warning to the caller of `parse_data`.
                stacklevel=2,
            )

            return parse_request(data)  # type: ignore

        raise NotImplementedError

    def merge_sampling_params(
        self,
        params: SamplingParams | None = None,
    ) -> SamplingParams:
        """Return the sampling parameters to use for a request.

        If the instance still defines the deprecated callable
        `validate_or_generate_params`, a `DeprecationWarning` is emitted and
        its result is returned instead.

        Args:
            params: Caller-provided sampling parameters, if any.

        Returns:
            `params` when it is truthy, otherwise a default-constructed
            `SamplingParams()` (or the legacy hook's result, see above).
        """
        if callable(
            validate_or_generate_params := getattr(
                self, "validate_or_generate_params", None
            )
        ):
            warnings.warn(
                "`validate_or_generate_params` has been split into "
                # The trailing space keeps the two sentences separated once
                # the adjacent literals are concatenated.
                "`merge_sampling_params` and `merge_pooling_params`. "
                "Please update your IO Processor Plugin to use the new methods. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                # Attribute the warning to the caller of this method.
                stacklevel=2,
            )

            return validate_or_generate_params(params)  # type: ignore

        return params or SamplingParams()

    def merge_pooling_params(
        self,
        params: PoolingParams | None = None,
    ) -> PoolingParams:
        """Return the pooling parameters to use for a request.

        If the instance still defines the deprecated callable
        `validate_or_generate_params`, a `DeprecationWarning` is emitted and
        its result is returned instead.

        Args:
            params: Caller-provided pooling parameters, if any.

        Returns:
            `params` when it is truthy, otherwise
            `PoolingParams(task="plugin")` (or the legacy hook's result, see
            above).
        """
        if callable(
            validate_or_generate_params := getattr(
                self, "validate_or_generate_params", None
            )
        ):
            warnings.warn(
                "`validate_or_generate_params` has been split into "
                # The trailing space keeps the two sentences separated once
                # the adjacent literals are concatenated.
                "`merge_sampling_params` and `merge_pooling_params`. "
                "Please update your IO Processor Plugin to use the new methods. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                # Attribute the warning to the caller of this method.
                stacklevel=2,
            )

            return validate_or_generate_params(params)  # type: ignore

        return params or PoolingParams(task="plugin")

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Turn plugin input into a prompt or a sequence of prompts.

        Must be implemented by subclasses.

        Args:
            prompt: Plugin input, of the type produced by `parse_data`.
            request_id: Optional request identifier.
            **kwargs: Additional implementation-specific arguments.

        Returns:
            A single `PromptType` or a sequence of them.
        """
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Async variant of `pre_process`.

        The default implementation simply calls the synchronous
        `pre_process` with the same arguments.
        """
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Turn model outputs into the plugin's output type.

        Must be implemented by subclasses.

        Args:
            model_output: Sequence of pooling request outputs.
            request_id: Optional request identifier.
            **kwargs: Additional implementation-specific arguments.

        Returns:
            The plugin-specific output object.
        """
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Async variant of `post_process`.

        Drains `model_output`, which yields `(index, output)` pairs, orders
        the outputs by index and passes them to the synchronous
        `post_process`.

        Args:
            model_output: Async generator of `(index, output)` pairs.
            request_id: Optional request identifier, forwarded as a keyword.
            **kwargs: Forwarded unchanged to `post_process`.

        Returns:
            The result of `post_process` on the index-ordered outputs.
        """
        # We cannot guarantee outputs are returned in the same order they were
        # fed to vLLM.
        # Let's sort them by id before post_processing
        sorted_output = sorted(
            [(i, item) async for i, item in model_output], key=lambda output: output[0]
        )
        collected_output = [output[1] for output in sorted_output]
        return self.post_process(collected_output, request_id=request_id, **kwargs)