serving.py 5.95 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Final, TypeAlias
5

6
import jinja2
7
8
9
10
import numpy as np
from fastapi import Request

from vllm.engine.protocol import EngineClient
11
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
12
from vllm.entrypoints.logger import RequestLogger
13
14
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
15
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
16
17
18
19
20
21
22
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
)
23
from vllm.logger import init_logger
24
from vllm.outputs import ClassificationOutput
25
from vllm.pooling_params import PoolingParams
26
27
28
29

# Module-level logger, created through vllm.logger.init_logger with this
# module's name.
logger = init_logger(__name__)


30
# ServeContext specialised to ClassificationRequest; used to annotate the
# `ctx` parameter of the ServingClassification methods.
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56


class ServingClassification(OpenAIServing):
    """Serving handler for classification requests.

    Accepts chat-style and completion-style classification requests,
    converts them into engine prompts, and formats the engine outputs as a
    ``ClassificationResponse``.
    """

    request_id_prefix = "classify"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Initialise the base serving class and store chat-template settings.

        Args:
            engine_client: Forwarded to ``OpenAIServing.__init__``.
            models: Forwarded to ``OpenAIServing.__init__``.
            request_logger: Forwarded to ``OpenAIServing.__init__``.
            chat_template: Default template passed to ``_preprocess_chat``
                for chat-style requests.
            chat_template_content_format: Default content format passed to
                ``_preprocess_chat`` for chat-style requests.
            trust_request_chat_template: Passed to
                ``_validate_chat_template`` when a chat request is validated.
            log_error_stack: Forwarded to ``OpenAIServing.__init__``.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def _preprocess(
        self,
        ctx: ClassificationServeContext,
    ) -> ErrorResponse | None:
        """
        Process classification inputs: tokenize text, resolve adapters,
        and prepare model-specific inputs.

        On success the results are stored on ``ctx`` (``lora_request`` and
        ``engine_prompts``) and ``None`` is returned. An ``ErrorResponse`` is
        returned when chat-template validation fails, when the request is
        neither a chat nor a completion classification request, or when
        preprocessing raises ``ValueError``, ``TypeError`` or
        ``jinja2.TemplateError``.
        """
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            if isinstance(ctx.request, ClassificationChatRequest):
                error_check_ret = self._validate_chat_template(
                    request_chat_template=ctx.request.chat_template,
                    chat_template_kwargs=ctx.request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret:
                    return error_check_ret

                # Only the engine prompts are needed; the first element of
                # the returned pair is discarded.
                _, ctx.engine_prompts = await self._preprocess_chat(
                    ctx.request,
                    ctx.request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=None,
                )
            elif isinstance(ctx.request, ClassificationCompletionRequest):
                ctx.engine_prompts = await self._preprocess_completion(
                    ctx.request,
                    prompt_input=ctx.request.input,
                    prompt_embeds=None,
                )
            else:
                return self.create_error_response("Invalid classification request type")

            return None

        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_response(
        self,
        ctx: ClassificationServeContext,
    ) -> ClassificationResponse | ErrorResponse:
        """
        Convert model outputs to a formatted classification response
        with probabilities and labels.

        For every entry in ``ctx.final_res_batch`` the index of the highest
        probability is looked up in the model config's ``id2label`` mapping;
        the label is ``None`` when the mapping has no entry for that index.
        Usage counts prompt tokens only, so ``total_tokens`` equals
        ``prompt_tokens``.
        """
        # `id2label` may be absent or explicitly None on the HF config; fall
        # back to an empty mapping so the lookup below simply yields None.
        id2label = getattr(self.model_config.hf_config, "id2label", None) or {}

        items: list[ClassificationData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(ctx.final_res_batch):
            classify_res = ClassificationOutput.from_base(final_res.outputs)

            probs = classify_res.probs
            predicted_index = int(np.argmax(probs))
            label = id2label.get(predicted_index)

            items.append(
                ClassificationData(
                    index=idx,
                    label=label,
                    probs=probs,
                    num_classes=len(probs),
                )
            )
            num_prompt_tokens += len(final_res.prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return ClassificationResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

    async def create_classify(
        self,
        request: ClassificationRequest,
        raw_request: Request,
    ) -> ClassificationResponse | ErrorResponse:
        """Handle a classification request end to end.

        Builds a request id from ``request_id_prefix`` and
        ``_base_request_id(raw_request)``, wraps the request in a
        ``ClassificationServeContext`` and delegates to ``self.handle``.
        """
        model_name = self.models.model_name()
        request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"

        ctx = ClassificationServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
        )

        return await self.handle(ctx)  # type: ignore[return-value]

    def _create_pooling_params(
        self,
        ctx: ClassificationServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Create pooling params and verify them for the "classify" task.

        An ``ErrorResponse`` from the base implementation is returned
        unchanged. A ``ValueError`` raised by ``PoolingParams.verify`` is
        converted into an ``ErrorResponse`` carrying the exception message.
        """
        pooling_params = super()._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        try:
            pooling_params.verify("classify", self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))

        return pooling_params