# serving.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from http import HTTPStatus
5
from typing import Final, cast
6

7
import jinja2
8
9
10
11
import numpy as np
from fastapi import Request

from vllm.engine.protocol import EngineClient
12
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
13
from vllm.entrypoints.logger import RequestLogger
14
15
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
16
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
17
18
19
20
21
22
23
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
)
24
from vllm.entrypoints.renderer import RenderConfig
25
26
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
27
from vllm.pooling_params import PoolingParams
28
29
30
31

# Module-level logger, named after this module.
logger = init_logger(__name__)


# Serve context specialised for classification requests; used as the type of
# the ``ctx`` argument throughout ``ServingClassification``.
ClassificationServeContext = ServeContext[ClassificationRequest]


class ServingClassification(OpenAIServing):
    """Serving handler for classification requests.

    ``create_classify`` wraps an incoming request in a
    ``ClassificationServeContext`` and passes it to the inherited ``handle``
    method. The remaining methods provide the classification-specific steps
    (input preprocessing, pooling-parameter validation, response assembly).

    NOTE(review): ``_preprocess`` and ``_build_response`` are presumably
    invoked by the inherited ``handle`` pipeline -- confirm against
    ``OpenAIServing``.
    """

    # Prefix prepended to the base request id in ``create_classify``.
    request_id_prefix = "classify"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Initialise the handler.

        Args:
            engine_client: Forwarded to the base-class constructor.
            models: Forwarded to the base-class constructor.
            request_logger: Forwarded to the base-class constructor.
            chat_template: Fallback chat template used when a chat request
                does not carry its own template.
            chat_template_content_format: Content format passed to chat
                preprocessing.
            trust_request_chat_template: Passed to chat-template validation
                for chat requests.
            log_error_stack: Forwarded to the base-class constructor.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def _preprocess(
        self,
        ctx: ClassificationServeContext,
    ) -> ErrorResponse | None:
        """Process classification inputs and populate ``ctx``.

        Stores the result of ``_maybe_get_adapters`` in ``ctx.lora_request``
        and fills ``ctx.engine_prompts`` from either the chat messages or the
        completion-style ``input`` of the request.

        Returns:
            ``None`` on success, or an ``ErrorResponse`` when the chat
            template is rejected, the input is missing, the request type is
            unsupported, or rendering raises ``ValueError``, ``TypeError`` or
            ``jinja2.TemplateError``.
        """
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            if isinstance(ctx.request, ClassificationChatRequest):
                error_check_ret = self._validate_chat_template(
                    request_chat_template=ctx.request.chat_template,
                    chat_template_kwargs=ctx.request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret:
                    return error_check_ret

                # Only the engine prompts are needed; the first element of
                # the returned pair is discarded.
                _, engine_prompts = await self._preprocess_chat(
                    ctx.request,
                    self.renderer,
                    ctx.request.messages,
                    # A template supplied on the request takes precedence
                    # over the one configured on this handler.
                    chat_template=ctx.request.chat_template or self.chat_template,
                    chat_template_content_format=self.chat_template_content_format,
                    add_generation_prompt=ctx.request.add_generation_prompt,
                    continue_final_message=ctx.request.continue_final_message,
                    add_special_tokens=ctx.request.add_special_tokens,
                )
                ctx.engine_prompts = engine_prompts

            elif isinstance(ctx.request, ClassificationCompletionRequest):
                input_data = ctx.request.input
                if input_data is None or input_data == "":
                    return self.create_error_response(
                        "Input or messages must be provided",
                        status_code=HTTPStatus.BAD_REQUEST,
                    )
                # An empty list is not an error: there is nothing to render.
                if isinstance(input_data, list) and not input_data:
                    ctx.engine_prompts = []
                    return None

                renderer = self._get_completion_renderer()
                prompt_input = cast(str | list[str], input_data)
                ctx.engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=prompt_input,
                    config=self._build_render_config(ctx.request),
                )
            else:
                return self.create_error_response("Invalid classification request type")

            return None

        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_response(
        self,
        ctx: ClassificationServeContext,
    ) -> ClassificationResponse | ErrorResponse:
        """Convert model outputs to a formatted classification response.

        For every entry of ``ctx.final_res_batch`` the class probabilities
        are read, the index of the highest probability is looked up in the
        model's ``id2label`` mapping, and a ``ClassificationData`` item is
        built. Prompt token counts are summed into the usage information.

        Returns:
            A ``ClassificationResponse`` holding one item per batch entry.
        """
        # ``id2label`` may be absent or explicitly ``None`` on the HF config;
        # fall back to an empty mapping so the lookup below yields ``None``.
        id2label = getattr(self.model_config.hf_config, "id2label", None) or {}

        items: list[ClassificationData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)

        for idx, final_res in enumerate(final_res_batch_checked):
            classify_res = ClassificationOutput.from_base(final_res.outputs)

            probs = classify_res.probs
            predicted_index = int(np.argmax(probs))
            # ``label`` is ``None`` when the index has no entry in the mapping.
            label = id2label.get(predicted_index)

            item = ClassificationData(
                index=idx,
                label=label,
                probs=probs,
                num_classes=len(probs),
            )

            items.append(item)
            prompt_token_ids = final_res.prompt_token_ids
            num_prompt_tokens += len(prompt_token_ids)

        # Only prompt tokens are counted, so the total equals the prompt count.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return ClassificationResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

    def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
        """Build the ``RenderConfig`` used to render completion-style input.

        The maximum length comes from ``self.max_model_len``; truncation and
        special-token settings are taken from ``request``.
        """
        return RenderConfig(
            max_length=self.max_model_len,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
        )

    async def create_classify(
        self,
        request: ClassificationRequest,
        raw_request: Request,
    ) -> ClassificationResponse | ErrorResponse:
        """Handle a classification request.

        Builds a ``ClassificationServeContext`` with the model name and a
        request id of the form ``"<request_id_prefix>-<base request id>"``,
        then delegates to the inherited ``handle`` method.

        Args:
            request: The classification request.
            raw_request: The underlying FastAPI request.

        Returns:
            Whatever ``handle`` returns for the context.
        """
        model_name = self.models.model_name()
        request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"

        ctx = ClassificationServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
        )

        return await self.handle(ctx)  # type: ignore[return-value]

    def _create_pooling_params(
        self,
        ctx: ClassificationServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Create pooling parameters and verify them for the classify task.

        Returns:
            The pooling parameters produced by the base class, or an
            ``ErrorResponse`` if the base class returned one or if
            ``verify("classify", ...)`` raises ``ValueError``.
        """
        pooling_params = super()._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        try:
            pooling_params.verify("classify", self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))

        return pooling_params