interface.py 4.16 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import warnings
4
5
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
6
from typing import Generic, TypeVar
7
8
9
10

from vllm.config import VllmConfig
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
11
12
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
13

14
15
IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
16
17
18


class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Abstract interface for pre/post-processing of engine I/O.

    Concrete plugins must implement `pre_process` and `post_process`.
    `parse_data`, `merge_sampling_params` and `merge_pooling_params` ship
    with default implementations that also act as compatibility shims: if
    the plugin still defines the deprecated `parse_request` or
    `validate_or_generate_params` hooks, those are called (with a
    `DeprecationWarning`) instead of the default behavior.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Keep a reference to the engine configuration.

        Args:
            vllm_config: Configuration object stored as `self.vllm_config`.
        """
        super().__init__()

        self.vllm_config = vllm_config

    @staticmethod
    def _warn_validate_or_generate_params_deprecated() -> None:
        """Emit the deprecation warning for `validate_or_generate_params`.

        Shared by `merge_sampling_params` and `merge_pooling_params` so the
        message is written once.
        """
        warnings.warn(
            "`validate_or_generate_params` has been split into "
            "`merge_sampling_params` and `merge_pooling_params`. "
            "Please update your IO Processor Plugin to use the new methods. "
            "The old name will be removed in v0.19.",
            DeprecationWarning,
            # Frame 1 is this helper and frame 2 is the `merge_*_params`
            # method, so level 3 attributes the warning to that method's
            # caller.
            stacklevel=3,
        )

    def parse_data(self, data: object) -> IOProcessorInput:
        """Convert raw request data into this plugin's input type.

        Args:
            data: The raw object to parse.

        Returns:
            The parsed plugin input. The default implementation only
            returns a value when the deprecated `parse_request` hook exists
            on the instance, in which case its result is returned.

        Raises:
            NotImplementedError: If the plugin overrides neither
                `parse_data` nor the deprecated `parse_request`.
        """
        legacy_parser = getattr(self, "parse_request", None)
        if not callable(legacy_parser):
            raise NotImplementedError

        warnings.warn(
            "`parse_request` has been renamed to `parse_data`. "
            "Please update your IO Processor Plugin to use the new name. "
            "The old name will be removed in v0.19.",
            DeprecationWarning,
            stacklevel=2,
        )
        return legacy_parser(data)  # type: ignore

    def merge_sampling_params(
        self,
        params: SamplingParams | None = None,
    ) -> SamplingParams:
        """Return the sampling parameters to use for a request.

        Args:
            params: Parameters supplied with the request, if any.

        Returns:
            The result of the deprecated `validate_or_generate_params` hook
            when the plugin defines one; otherwise `params` itself, or a
            default-constructed `SamplingParams` when `params` is None.
        """
        legacy_hook = getattr(self, "validate_or_generate_params", None)
        if callable(legacy_hook):
            self._warn_validate_or_generate_params_deprecated()
            return legacy_hook(params)  # type: ignore

        # Test against None explicitly rather than relying on truthiness.
        return params if params is not None else SamplingParams()

    def merge_pooling_params(
        self,
        params: PoolingParams | None = None,
    ) -> PoolingParams:
        """Return the pooling parameters to use for a request.

        Args:
            params: Parameters supplied with the request, if any.

        Returns:
            The result of the deprecated `validate_or_generate_params` hook
            when the plugin defines one; otherwise `params` itself, or
            `PoolingParams(task="plugin")` when `params` is None.
        """
        legacy_hook = getattr(self, "validate_or_generate_params", None)
        if callable(legacy_hook):
            self._warn_validate_or_generate_params_deprecated()
            return legacy_hook(params)  # type: ignore

        # Test against None explicitly rather than relying on truthiness.
        return params if params is not None else PoolingParams(task="plugin")

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Turn plugin input into one prompt or a sequence of prompts.

        Args:
            prompt: Input in the plugin's own input type.
            request_id: Optional identifier of the request.
            **kwargs: Extra keyword arguments; their meaning is defined by
                the implementing plugin.

        Returns:
            A single prompt or a sequence of prompts.
        """
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Async counterpart of `pre_process`.

        The default implementation simply calls the synchronous
        `pre_process` with the same arguments and returns its result.
        """
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Turn pooling outputs into the plugin's output type.

        Args:
            model_output: The pooling request outputs to convert.
            request_id: Optional identifier of the request.
            **kwargs: Extra keyword arguments; their meaning is defined by
                the implementing plugin.

        Returns:
            The plugin-specific output object.
        """
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Async counterpart of `post_process`.

        Drains `model_output`, restores the order given by the integer id
        paired with each item, and forwards the ordered outputs to the
        synchronous `post_process`.

        Args:
            model_output: Async generator yielding `(id, output)` pairs.
            request_id: Optional identifier of the request.
            **kwargs: Forwarded unchanged to `post_process`.

        Returns:
            Whatever `post_process` returns for the ordered outputs.
        """
        # We cannot guarantee outputs are returned in the same order they
        # were fed to vLLM, so collect everything and sort by id before
        # post-processing.
        indexed_outputs = [(i, item) async for i, item in model_output]
        indexed_outputs.sort(key=lambda pair: pair[0])

        ordered_outputs = [item for _, item in indexed_outputs]
        return self.post_process(ordered_outputs, request_id=request_id, **kwargs)