# serving_engine.py
import json
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import Field
from typing_extensions import Annotated

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              LogProbs, ModelCard, ModelList,
                                              ModelPermission)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import get_tokenizer

# Module-level logger, obtained from vLLM's init_logger helper and named
# after this module.
logger = init_logger(__name__)


@dataclass
class LoRAModulePath:
    """A LoRA adapter to expose through the OpenAI-compatible server.

    ``OpenAIServing`` turns each instance into a ``LoRARequest``.
    """

    # Model name clients put in ``request.model`` to select this adapter;
    # also reported as the adapter's model-card id.
    name: str
    # Local path to the adapter, forwarded as ``lora_local_path`` when the
    # corresponding ``LoRARequest`` is built.
    local_path: str


class OpenAIServing:
    """Common state and helpers for the OpenAI-compatible endpoints.

    Holds the engine handle, a tokenizer built from the model config, the
    names the base model is served under and one ``LoRARequest`` per
    configured adapter. Provides model listing, logprob formatting, error
    responses, model/adapter lookup and prompt validation/tokenization.
    """

    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                 served_model_names: List[str],
                 lora_modules: Optional[List[LoRAModulePath]]):
        """Store the engine and build serving state from ``model_config``.

        Args:
            engine: Engine handle, kept as ``self.engine``.
            model_config: Supplies ``max_model_len`` and the tokenizer
                settings (name, mode, revision, ``trust_remote_code``).
            served_model_names: Names the base model answers to; the first
                one is reported as ``root`` on every model card.
            lora_modules: Adapters to expose, or ``None`` for none.
        """
        super().__init__()

        self.engine = engine
        self.max_model_len = model_config.max_model_len

        # A separate tokenizer to map token IDs to strings. Left-side
        # truncation drops tokens from the start of an over-long prompt,
        # matching the ``prompt_ids[-n:]`` slice applied to pre-tokenized
        # prompts in ``_validate_prompt_and_tokenize``.
        self.tokenizer = get_tokenizer(
            model_config.tokenizer,
            tokenizer_mode=model_config.tokenizer_mode,
            tokenizer_revision=model_config.tokenizer_revision,
            trust_remote_code=model_config.trust_remote_code,
            truncation_side="left")

        self.served_model_names = served_model_names

        # Integer adapter ids are assigned from 1 in list order
        # (NOTE(review): presumably 0 is reserved by LoRARequest - confirm).
        self.lora_requests = [
            LoRARequest(
                lora_name=lora.name,
                lora_int_id=i,
                lora_local_path=lora.local_path,
            ) for i, lora in enumerate(lora_modules or [], start=1)
        ]

    async def show_available_models(self) -> ModelList:
        """Show available models.

        Returns one card per served name of the base model, followed by one
        card per LoRA adapter. Every card's ``root`` is the first served
        model name.
        """
        model_cards = [
            ModelCard(id=served_model_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for served_model_name in self.served_model_names
        ]
        model_cards.extend(
            ModelCard(id=lora.lora_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for lora in self.lora_requests)
        return ModelList(data=model_cards)

    def _create_logprobs(
        self,
        token_ids: List[int],
        top_logprobs: List[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: Optional[int] = None,
        initial_text_offset: int = 0,
    ) -> LogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: Token ids to report, in order.
            top_logprobs: Per-position map from token id to ``Logprob``,
                parallel to ``token_ids``. For a ``None`` entry the token is
                decoded with ``self.tokenizer`` and reported with a ``None``
                logprob.
            num_output_top_logprobs: When truthy, the result's
                ``top_logprobs`` list is populated with one entry per token;
                otherwise it is left unset.
            initial_text_offset: Character offset of the first token.

        Returns:
            A ``LogProbs`` with parallel ``tokens``, ``token_logprobs`` and
            ``text_offset`` lists (plus ``top_logprobs`` when requested).
            ``-inf`` logprobs are clamped to ``-9999.0``.
        """
        logprobs = LogProbs()
        last_token_len = 0
        if num_output_top_logprobs:
            logprobs.top_logprobs = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = self.tokenizer.decode(token_id)
                logprobs.tokens.append(token)
                logprobs.token_logprobs.append(None)
                # top_logprobs only exists when it was requested; appending
                # unconditionally failed the assertion whenever
                # num_output_top_logprobs was 0 or None.
                if num_output_top_logprobs:
                    assert logprobs.top_logprobs is not None
                    logprobs.top_logprobs.append(None)
            else:
                chosen = step_top_logprobs[token_id]
                token = chosen.decoded_token
                logprobs.tokens.append(token)
                # Convert float("-inf") to the JSON-serializable float that
                # OpenAI uses.
                logprobs.token_logprobs.append(max(chosen.logprob, -9999.0))

                if num_output_top_logprobs:
                    assert logprobs.top_logprobs is not None
                    logprobs.top_logprobs.append({
                        p.decoded_token: max(p.logprob, -9999.0)
                        for p in step_top_logprobs.values()
                    })

            # Each offset is the previous offset plus the previous token's
            # character length.
            if logprobs.text_offset:
                logprobs.text_offset.append(logprobs.text_offset[-1] +
                                            last_token_len)
            else:
                logprobs.text_offset.append(initial_text_offset)
            last_token_len = len(token)
        return logprobs

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an ``ErrorResponse`` whose ``code`` is the numeric value of
        ``status_code`` (400 for the default ``BAD_REQUEST``)."""
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Serialize an error as the JSON string ``{"error": {...}}``.

        The inner object is ``create_error_response(...).model_dump()``.
        """
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

    async def _check_model(
        self, request: Union[CompletionRequest, ChatCompletionRequest]
    ) -> Optional[ErrorResponse]:
        """Return ``None`` if ``request.model`` names the base model or a
        configured LoRA adapter; otherwise a 404 ``NotFoundError``."""
        if request.model in self.served_model_names:
            return None
        if any(request.model == lora.lora_name
               for lora in self.lora_requests):
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_lora(
        self, request: Union[CompletionRequest, ChatCompletionRequest]
    ) -> Optional[LoRARequest]:
        """Return the ``LoRARequest`` selected by ``request.model``.

        Returns ``None`` when the base model is requested.

        Raises:
            ValueError: If the name matches neither the base model nor an
                adapter.
        """
        if request.model in self.served_model_names:
            return None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _validate_prompt_and_tokenize(
            self,
            request: Union[ChatCompletionRequest, CompletionRequest,
                           EmbeddingRequest],
            prompt: Optional[str] = None,
            prompt_ids: Optional[List[int]] = None,
            truncate_prompt_tokens: Optional[Annotated[int,
                                                       Field(ge=1)]] = None,
            add_special_tokens: bool = True) -> Tuple[List[int], str]:
        """Tokenize (or accept pre-tokenized) a prompt and check it fits.

        Exactly one of ``prompt`` / ``prompt_ids`` must be non-empty.

        Args:
            request: Request being served. For non-embedding requests whose
                ``max_tokens`` is ``None``, ``max_tokens`` is set in place
                to the remaining context budget.
            prompt: Prompt text to tokenize.
            prompt_ids: Already-tokenized prompt.
            truncate_prompt_tokens: If set, keep only the last this-many
                prompt tokens.
            add_special_tokens: Forwarded to the tokenizer for text prompts.

        Returns:
            ``(input_ids, input_text)``, where ``input_text`` is ``prompt``
            when given, else the decoding of ``input_ids``.

        Raises:
            ValueError: If neither or both prompt forms are given, or the
                prompt (plus ``max_tokens``) exceeds ``max_model_len``.
        """
        if not (prompt or prompt_ids):
            raise ValueError("Either prompt or prompt_ids should be provided.")
        if (prompt and prompt_ids):
            raise ValueError(
                "Only one of prompt or prompt_ids should be provided.")

        if prompt_ids is None:
            # When using OpenAIServingChat for chat completions, the
            # special tokens (e.g., BOS) have already been added by the
            # chat template. Therefore, we do not need to add them again.
            # Set add_special_tokens to False to avoid adding the BOS tokens
            # again.
            tokenizer_kwargs: Dict[str, Any] = {
                "add_special_tokens": add_special_tokens
            }
            if truncate_prompt_tokens is not None:
                tokenizer_kwargs.update({
                    "truncation": True,
                    "max_length": truncate_prompt_tokens,
                })
            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
        elif truncate_prompt_tokens is not None:
            input_ids = prompt_ids[-truncate_prompt_tokens:]
        else:
            input_ids = prompt_ids

        # Decode the ids that will actually be used (after truncation), not
        # the caller's full prompt_ids, so text and ids agree.
        input_text = prompt if prompt is not None else self.tokenizer.decode(
            input_ids)

        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.")
            return input_ids, input_text

        if request.max_tokens is None:
            # At least one token must remain for the completion.
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
            request.max_tokens = self.max_model_len - token_num

        if token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")
        return input_ids, input_text