interface.py 4.22 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import warnings
4
5
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
6
from typing import Generic, TypeVar
7
8
9
10

from vllm.config import VllmConfig
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
11
from vllm.pooling_params import PoolingParams
12
from vllm.renderers import BaseRenderer
13
from vllm.sampling_params import SamplingParams
14

15
16
IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
17
18
19


class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
20
21
    """Abstract interface for pre/post-processing of engine I/O."""

22
    def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer):
23
24
        super().__init__()

25
26
        self.vllm_config = vllm_config

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
    def parse_data(self, data: object) -> IOProcessorInput:
        """Convert raw request data into this processor's input type.

        The base implementation is only a backward-compatibility shim: when
        the instance exposes a callable ``parse_request`` attribute (the former
        name of this hook), a ``DeprecationWarning`` is emitted and the call
        is forwarded to it.

        Args:
            data: The raw payload to parse.

        Returns:
            Whatever the legacy ``parse_request`` hook returns.

        Raises:
            NotImplementedError: If no callable ``parse_request`` is present.
        """
        legacy_hook = getattr(self, "parse_request", None)
        if not callable(legacy_hook):
            raise NotImplementedError

        warnings.warn(
            "`parse_request` has been renamed to `parse_data`. "
            "Please update your IO Processor Plugin to use the new name. "
            "The old name will be removed in v0.19.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Forward to the legacy hook so not-yet-migrated plugins keep working.
        return legacy_hook(data)  # type: ignore

    def merge_sampling_params(
        self,
        params: SamplingParams | None = None,
    ) -> SamplingParams:
        """Return the sampling parameters to use for a request.

        For backward compatibility, if the instance exposes a callable
        ``validate_or_generate_params`` attribute (the hook this method
        partially replaces), a ``DeprecationWarning`` is emitted and the call
        is delegated to that hook.

        Args:
            params: Caller-supplied sampling parameters, if any.

        Returns:
            The result of the legacy hook when one exists; otherwise
            ``params`` when it is truthy, else a default ``SamplingParams()``.
        """
        legacy_hook = getattr(self, "validate_or_generate_params", None)
        if callable(legacy_hook):
            # Each fragment ends with a space so the implicitly concatenated
            # sentences do not run together in the emitted message.
            warnings.warn(
                "`validate_or_generate_params` has been split into "
                "`merge_sampling_params` and `merge_pooling_params`. "
                "Please update your IO Processor Plugin to use the new methods. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                stacklevel=2,
            )

            return legacy_hook(params)  # type: ignore

        return params or SamplingParams()

    def merge_pooling_params(
        self,
        params: PoolingParams | None = None,
    ) -> PoolingParams:
        """Return the pooling parameters to use for a request.

        For backward compatibility, if the instance exposes a callable
        ``validate_or_generate_params`` attribute (the hook this method
        partially replaces), a ``DeprecationWarning`` is emitted and the call
        is delegated to that hook.

        Args:
            params: Caller-supplied pooling parameters, if any.

        Returns:
            The result of the legacy hook when one exists; otherwise
            ``params`` when it is truthy, else ``PoolingParams(task="plugin")``.
        """
        legacy_hook = getattr(self, "validate_or_generate_params", None)
        if callable(legacy_hook):
            # Each fragment ends with a space so the implicitly concatenated
            # sentences do not run together in the emitted message.
            warnings.warn(
                "`validate_or_generate_params` has been split into "
                "`merge_sampling_params` and `merge_pooling_params`. "
                "Please update your IO Processor Plugin to use the new methods. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                stacklevel=2,
            )

            return legacy_hook(params)  # type: ignore

        return params or PoolingParams(task="plugin")

85
86
87
88
    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Turn a plugin-specific input into one or more prompts.

        Abstract hook: concrete IO processors must override it.

        Args:
            prompt: The plugin-specific input to convert.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra, implementation-defined options.

        Returns:
            A single prompt or a sequence of prompts.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Asynchronous counterpart of ``pre_process``.

        The default implementation calls the synchronous ``pre_process``
        directly and returns its result; it does not await anything itself.

        Args:
            prompt: The plugin-specific input to convert.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra options forwarded unchanged to ``pre_process``.

        Returns:
            Whatever ``pre_process`` returns for the same arguments.
        """
        # request_id is forwarded positionally, exactly as received.
        prompts = self.pre_process(prompt, request_id, **kwargs)
        return prompts

    @abstractmethod
    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Turn model outputs into the plugin-specific output.

        Abstract hook: concrete IO processors must override it.

        Args:
            model_output: The pooling outputs to convert.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra, implementation-defined options.

        Returns:
            The plugin-specific output object.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Asynchronous counterpart of ``post_process``.

        Drains ``model_output`` completely, restores the order given by the
        integer that accompanies each item, and hands the ordered items to the
        synchronous ``post_process``.

        Args:
            model_output: Async stream of ``(index, output)`` pairs.
            request_id: Optional identifier of the request being processed.
            **kwargs: Extra options forwarded unchanged to ``post_process``.

        Returns:
            Whatever ``post_process`` returns for the ordered outputs.
        """
        # Outputs are not guaranteed to arrive in the order they were
        # submitted, so gather everything first and reorder by index.
        indexed_outputs = [(idx, out) async for idx, out in model_output]
        indexed_outputs.sort(key=lambda pair: pair[0])

        ordered_outputs = [out for _, out in indexed_outputs]
        return self.post_process(ordered_outputs, request_id=request_id, **kwargs)