# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TypeAlias

import numpy as np
from fastapi.responses import JSONResponse

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput
from vllm.renderers import BaseRenderer

from .io_processor import ClassifyIOProcessor
from .protocol import (
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
)

logger = init_logger(__name__)

# Serve context specialized to classification requests; this is the type of
# the ``ctx`` argument received by ``ServingClassification._build_response``.
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]


class ServingClassification(PoolingServing):
    """Serving handler for classification requests.

    Input handling is delegated to the :class:`ClassifyIOProcessor` built in
    :meth:`init_io_processor`. :meth:`_build_response` turns the finished
    pooled outputs into a :class:`ClassificationResponse`.
    """

    # Prefix used for request ids generated by this handler.
    request_id_prefix = "classify"

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> ClassifyIOProcessor:
        """Build the IO processor used for classification requests.

        Args:
            model_config: Model configuration forwarded to the processor.
            renderer: Renderer forwarded to the processor.
            chat_template_config: Chat-template settings forwarded to the
                processor.

        Returns:
            A new :class:`ClassifyIOProcessor` constructed from the arguments.
        """
        return ClassifyIOProcessor(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )

    async def _build_response(
        self,
        ctx: ClassificationServeContext,
    ) -> JSONResponse:
        """Convert the finished results in ``ctx`` into a JSON response.

        Each entry of ``ctx.final_res_batch`` becomes one
        :class:`ClassificationData` item, kept in batch order (``index`` is
        its position). The item carries the class probabilities, the number of
        classes, and the label of the arg-max class looked up in the model
        config's ``id2label`` table. The label is ``None`` when the table has
        no entry for that index. Prompt token counts are summed into the usage
        block.

        Args:
            ctx: Serve context holding the finished results together with the
                request id, creation time and model name.

        Returns:
            A ``JSONResponse`` wrapping the dumped ``ClassificationResponse``.
        """
        # ``id2label`` may be missing or explicitly set to None on a config.
        # Normalize both cases to an empty mapping so ``.get`` yields None
        # instead of raising AttributeError.
        id2label = getattr(self.model_config.hf_config, "id2label", None) or {}

        num_prompt_tokens = 0
        items: list[ClassificationData] = []
        for idx, final_res in enumerate(ctx.final_res_batch):
            classify_res = ClassificationOutput.from_base(final_res.outputs)
            probs = classify_res.probs

            # np.argmax returns a NumPy integer (first maximum on ties).
            # Convert it to a plain Python int before the dict lookup.
            predicted_index = int(np.argmax(probs))

            items.append(
                ClassificationData(
                    index=idx,
                    label=id2label.get(predicted_index),
                    probs=probs,
                    num_classes=len(probs),
                )
            )
            num_prompt_tokens += len(final_res.prompt_token_ids)

        # Only prompt tokens are counted, so total_tokens mirrors prompt_tokens.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )
        response = ClassificationResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )
        return JSONResponse(content=response.model_dump())