serving.py 2.49 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TypeAlias
5
6

import numpy as np
7
8
9
10
11
12
13
14
15
16
17

from vllm import ClassificationOutput
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
from vllm.logger import init_logger
from vllm.renderers import BaseRenderer

from .io_processor import ClassifyIOProcessor
from .protocol import (
18
19
20
21
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
)
22
23
24
25

# Module-level logger, named after this module.
logger = init_logger(__name__)


26
# Pooling serve context parameterized with the classification request type.
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
27
28


29
class ServingClassification(PoolingServing):
    """Pooling serving handler that builds classification responses.

    Each pooled output in a finished batch is converted into a
    ``ClassificationData`` item (predicted label, class probabilities and
    class count) and the items are wrapped in a ``ClassificationResponse``.
    """

    request_id_prefix = "classify"

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> ClassifyIOProcessor:
        """Create the IO processor used for classification requests.

        Args:
            model_config: Model configuration forwarded to the processor.
            renderer: Renderer forwarded to the processor.
            chat_template_config: Chat template settings forwarded to the
                processor.

        Returns:
            A ``ClassifyIOProcessor`` built from the given arguments.
        """
        return ClassifyIOProcessor(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )

    async def _build_response(
        self,
        ctx: ClassificationServeContext,
    ) -> ClassificationResponse:
        """Assemble the classification response for a finished request.

        Args:
            ctx: Serve context providing the final result batch, the request
                id, the creation time and the model name.

        Returns:
            A ``ClassificationResponse`` with one ``ClassificationData`` entry
            per result (in batch order) and the prompt-token usage summed
            over the batch.
        """
        final_res_batch_checked = await self.io_processor.post_process_async(
            ctx.final_res_batch
        )

        # The getattr default only covers a *missing* attribute; a config that
        # explicitly sets ``id2label = None`` would otherwise break ``.get``
        # below, so normalize any falsy value to an empty mapping.
        id2label = getattr(self.model_config.hf_config, "id2label", None) or {}

        num_prompt_tokens = 0
        items: list[ClassificationData] = []
        for idx, final_res in enumerate(final_res_batch_checked):
            classify_res = ClassificationOutput.from_base(final_res.outputs)

            probs = classify_res.probs
            predicted_index = int(np.argmax(probs))
            # ``label`` is None when the config has no name for this index.
            label = id2label.get(predicted_index)

            items.append(
                ClassificationData(
                    index=idx,
                    label=label,
                    probs=probs,
                    num_classes=len(probs),
                )
            )
            num_prompt_tokens += len(final_res.prompt_token_ids)

        # Only prompt tokens are counted here, so the total equals the
        # prompt token count.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return ClassificationResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )