serving.py 5.45 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Final, TypeAlias
5

6
import jinja2
7
8
9
10
import numpy as np
from fastapi import Request

from vllm.engine.protocol import EngineClient
11
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
12
from vllm.entrypoints.logger import RequestLogger
13
14
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
15
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
16
17
18
19
20
21
22
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
)
23
from vllm.logger import init_logger
24
from vllm.outputs import ClassificationOutput
25
26
27
28

# Module-level logger, created from this module's ``__name__``.
logger = init_logger(__name__)


29
# ``ServeContext`` specialised to ``ClassificationRequest``; used to annotate
# the per-request context handled by the classification serving class.
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55


class ServingClassification(OpenAIServing):
    """Serving handler for classification requests, built on ``OpenAIServing``.

    ``create_classify`` is the entry point: it wraps the incoming request in a
    ``ClassificationServeContext`` and delegates to ``self.handle``.
    ``_preprocess`` and ``_build_response`` supply the classification-specific
    input preparation and response formatting (presumably invoked by the
    base-class ``handle`` pipeline -- confirm against ``OpenAIServing``).
    """

    request_id_prefix = "classify"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Store chat-template settings and initialise the base serving class.

        Args:
            engine_client: Forwarded to ``OpenAIServing.__init__``.
            models: Forwarded to ``OpenAIServing.__init__``; also used by
                ``create_classify`` to resolve the model name.
            request_logger: Forwarded to ``OpenAIServing.__init__``.
            chat_template: Default template passed to ``_preprocess_chat``
                for chat-style classification requests.
            chat_template_content_format: Default content format passed to
                ``_preprocess_chat``.
            trust_request_chat_template: Passed to ``_validate_chat_template``
                when a chat request is validated.
            log_error_stack: Forwarded to ``OpenAIServing.__init__``.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def _preprocess(
        self,
        ctx: ClassificationServeContext,
    ) -> ErrorResponse | None:
        """
        Process classification inputs: tokenize text, resolve adapters,
        and prepare model-specific inputs.

        On success ``ctx.lora_request`` and ``ctx.engine_prompts`` are
        populated and ``None`` is returned. An ``ErrorResponse`` is returned
        when chat-template validation fails, when the request is neither a
        chat nor a completion classification request, or when preprocessing
        raises ``ValueError``, ``TypeError`` or ``jinja2.TemplateError``.
        """
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            if isinstance(ctx.request, ClassificationChatRequest):
                error_check_ret = self._validate_chat_template(
                    request_chat_template=ctx.request.chat_template,
                    chat_template_kwargs=ctx.request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret:
                    return error_check_ret

                # Only the engine prompts are needed; the first element of
                # the returned pair is discarded.
                _, ctx.engine_prompts = await self._preprocess_chat(
                    ctx.request,
                    ctx.request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=None,
                )
            elif isinstance(ctx.request, ClassificationCompletionRequest):
                ctx.engine_prompts = await self._preprocess_completion(
                    ctx.request,
                    prompt_input=ctx.request.input,
                    prompt_embeds=None,
                )
            else:
                return self.create_error_response("Invalid classification request type")

            return None

        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_response(
        self,
        ctx: ClassificationServeContext,
    ) -> ClassificationResponse | ErrorResponse:
        """
        Convert model outputs to a formatted classification response
        with probabilities and labels.

        For every entry of ``ctx.final_res_batch`` the class with the highest
        probability is looked up in the model config's ``id2label`` mapping;
        the label is ``None`` when the index has no entry. Usage is reported
        as the summed prompt-token count of all entries.
        """
        # ``or {}`` also covers a config whose ``id2label`` attribute exists
        # but is None, which would otherwise fail on ``.get`` below.
        id2label = getattr(self.model_config.hf_config, "id2label", None) or {}

        items: list[ClassificationData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(ctx.final_res_batch):
            classify_res = ClassificationOutput.from_base(final_res.outputs)

            probs = classify_res.probs
            predicted_index = int(np.argmax(probs))
            label = id2label.get(predicted_index)

            items.append(
                ClassificationData(
                    index=idx,
                    label=label,
                    probs=probs,
                    num_classes=len(probs),
                )
            )
            num_prompt_tokens += len(final_res.prompt_token_ids)

        # Only prompt tokens are counted here, so the total equals the
        # prompt count.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return ClassificationResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

    async def create_classify(
        self,
        request: ClassificationRequest,
        raw_request: Request,
    ) -> ClassificationResponse | ErrorResponse:
        """Handle one classification request end to end.

        Builds a request id of the form ``classify-<base id>``, wraps the
        request in a ``ClassificationServeContext`` and returns the result of
        ``self.handle``.
        """
        model_name = self.models.model_name()
        request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"

        ctx = ClassificationServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
        )

        return await self.handle(ctx)  # type: ignore[return-value]