serving.py 2.15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TypeAlias
5
6

import numpy as np
7
from fastapi.responses import JSONResponse
8
9

from vllm.entrypoints.openai.engine.protocol import UsageInfo
10
11
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.typing import PoolingServeContext
12
from vllm.logger import init_logger
13
from vllm.outputs import ClassificationOutput
14
15
16

from .io_processor import ClassifyIOProcessor
from .protocol import (
17
18
19
20
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
)
21
22
23
24

# Module-level logger obtained from vLLM's `init_logger`, keyed by this
# module's import name.
logger = init_logger(__name__)


25
# Alias for the generic pooling serve context parameterized with
# `ClassificationRequest`; used to annotate `ctx` in this module.
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
26
27


28
class ServingClassification(PoolingServing):
    """Classification flavour of ``PoolingServing``.

    Supplies a ``ClassifyIOProcessor`` and converts a finished batch of
    results into a ``ClassificationResponse`` returned as JSON.
    """

    # NOTE(review): presumably consumed by the base class when generating
    # request ids -- confirm against ``PoolingServing``.
    request_id_prefix = "classify"

    def init_io_processor(self, *args, **kwargs) -> ClassifyIOProcessor:
        """Create the IO processor, forwarding all arguments unchanged."""
        return ClassifyIOProcessor(*args, **kwargs)

    def _build_response(
        self,
        ctx: ClassificationServeContext,
    ) -> JSONResponse:
        """Assemble the JSON response for a finished classification request.

        Args:
            ctx: Serve context. ``final_res_batch``, ``request_id``,
                ``created_time`` and ``model_name`` are read from it.

        Returns:
            A ``JSONResponse`` whose content is the dumped
            ``ClassificationResponse``: one ``ClassificationData`` entry per
            element of ``ctx.final_res_batch`` plus token usage.
        """
        # ``getattr``'s default only covers a *missing* attribute; ``or {}``
        # also covers a config whose ``id2label`` exists but is ``None``, so
        # the ``.get`` below cannot fail with ``AttributeError``.
        id2label = getattr(self.model_config.hf_config, "id2label", None) or {}

        num_prompt_tokens = 0
        items: list[ClassificationData] = []
        for idx, final_res in enumerate(ctx.final_res_batch):
            classify_res = ClassificationOutput.from_base(final_res.outputs)
            probs = classify_res.probs

            # The predicted class is the position of the largest probability;
            # ``label`` is ``None`` when that index has no ``id2label`` entry.
            predicted_index = int(np.argmax(probs))
            items.append(
                ClassificationData(
                    index=idx,
                    label=id2label.get(predicted_index),
                    probs=probs,
                    num_classes=len(probs),
                )
            )
            num_prompt_tokens += len(final_res.prompt_token_ids)

        # Only prompt tokens are counted here, so the total is set to the
        # same value.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        response = ClassificationResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

        return JSONResponse(content=response.model_dump())