serving_models.py 12.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import json
import pathlib
5
6
from asyncio import Lock
from collections import defaultdict
7
8
from dataclasses import dataclass
from http import HTTPStatus
9
from typing import Optional, Union
10
11

from vllm.config import ModelConfig
12
from vllm.engine.protocol import EngineClient
13
from vllm.entrypoints.openai.protocol import (ErrorResponse,
14
                                              LoadLoRAAdapterRequest,
15
16
                                              ModelCard, ModelList,
                                              ModelPermission,
17
                                              UnloadLoRAAdapterRequest)
18
from vllm.logger import init_logger
19
from vllm.lora.request import LoRARequest
20
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
21
22
23
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter

24
25
logger = init_logger(__name__)

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

@dataclass
class BaseModelPath:
    """Served name of a base model paired with the path it was loaded from."""
    # Name clients use to refer to the model; reported as the model card id.
    name: str
    # Location of the model; reported as the model card root.
    model_path: str


@dataclass
class PromptAdapterPath:
    """Served name of a prompt adapter paired with its local directory."""
    # Name under which the prompt adapter is listed.
    name: str
    # Directory expected to contain the adapter's "adapter_config.json".
    local_path: str


@dataclass
class LoRAModulePath:
    """Description of a LoRA module that is loaded statically at startup."""
    # Name the adapter is registered under.
    name: str
    # Path handed to the engine when the adapter is loaded.
    path: str
    # Optional name of the base model this adapter belongs to.
    base_model_name: Optional[str] = None


class OpenAIServingModels:
    """Shared instance to hold data about the loaded base model(s) and adapters.

    Handles the routes:
    - /v1/models
    - /v1/load_lora_adapter
    - /v1/unload_lora_adapter
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: list[BaseModelPath],
        *,
        lora_modules: Optional[list[LoRAModulePath]] = None,
        prompt_adapters: Optional[list[PromptAdapterPath]] = None,
    ):
        """Set up the registry of base models, LoRA and prompt adapters.

        Args:
            engine_client: Engine handle used later to register LoRA adapters.
            model_config: Model configuration; ``max_model_len`` is cached
                from it.
            base_model_paths: Base models being served.
            lora_modules: LoRA modules to load later through
                ``init_static_loras``; nothing is loaded here.
            prompt_adapters: Prompt adapters to register. Each adapter's
                ``adapter_config.json`` is read immediately.

        Raises:
            OSError: If an ``adapter_config.json`` cannot be opened.
            json.JSONDecodeError: If an ``adapter_config.json`` is not valid
                JSON.
            KeyError: If a config lacks the ``num_virtual_tokens`` entry.
        """
        super().__init__()

        self.base_model_paths = base_model_paths
        self.max_model_len = model_config.max_model_len
        self.engine_client = engine_client
        self.model_config = model_config

        self.static_lora_modules = lora_modules
        self.lora_requests: list[LoRARequest] = []
        self.lora_id_counter = AtomicCounter(0)

        # Instantiate every resolver the registry reports as supported.
        self.lora_resolvers: list[LoRAResolver] = []
        for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
        ):
            self.lora_resolvers.append(
                LoRAResolverRegistry.get_resolver(lora_resolver_name))
        # One lock per adapter name, created lazily on first access.
        self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)

        self.prompt_adapter_requests: list[PromptAdapterRequest] = []
        if prompt_adapters is not None:
            # Prompt adapter ids are assigned sequentially, starting at 1.
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
                config_file = pathlib.Path(prompt_adapter.local_path,
                                           "adapter_config.json")
                # Decode explicitly as UTF-8 instead of relying on the
                # platform's locale default encoding.
                with config_file.open(encoding="utf-8") as f:
                    adapter_config = json.load(f)
                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
                self.prompt_adapter_requests.append(
                    PromptAdapterRequest(
                        prompt_adapter_name=prompt_adapter.name,
                        prompt_adapter_id=i,
                        prompt_adapter_local_path=prompt_adapter.local_path,
                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))

96
97
98
99
100
101
    async def init_static_loras(self):
        """Load every statically configured LoRA module.

        Does nothing when no static modules were configured.

        Raises:
            ValueError: With the error message of the first module whose
                load attempt returned an ``ErrorResponse``.
        """
        for module in self.static_lora_modules or []:
            outcome = await self.load_lora_adapter(
                request=LoadLoRAAdapterRequest(lora_path=module.path,
                                               lora_name=module.name),
                base_model_name=module.base_model_name)
            if isinstance(outcome, ErrorResponse):
                raise ValueError(outcome.message)

109
    def is_base_model(self, model_name) -> bool:
        """Return True when `model_name` matches a served base model's name."""
        for base in self.base_model_paths:
            if base.name == model_name:
                return True
        return False

    def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
        """Pick the model name to report for a request.

        Parameters:
        - lora_request: LoRA request attached to the call, if any.
        Returns:
        - str: ``lora_request.lora_name`` when a LoRA request is given,
          otherwise the name of the first entry in ``base_model_paths``.
        """
        if lora_request is None:
            return self.base_model_paths[0].name
        return lora_request.lora_name

    async def show_available_models(self) -> ModelList:
        """List every servable model as a ``ModelList``.

        Cards appear in a fixed order: base models first, then loaded LoRA
        adapters, then prompt adapters.
        """
        base_cards = [
            ModelCard(id=entry.name,
                      max_model_len=self.max_model_len,
                      root=entry.model_path,
                      permission=[ModelPermission()])
            for entry in self.base_model_paths
        ]
        # A LoRA without a recorded base model is listed under the first
        # base model.
        adapter_cards = [
            ModelCard(id=adapter.lora_name,
                      root=adapter.local_path,
                      parent=(adapter.base_model_name
                              or self.base_model_paths[0].name),
                      permission=[ModelPermission()])
            for adapter in self.lora_requests
        ]
        prompt_cards = [
            ModelCard(id=adapter.prompt_adapter_name,
                      root=self.base_model_paths[0].name,
                      permission=[ModelPermission()])
            for adapter in self.prompt_adapter_requests
        ]
        return ModelList(data=[*base_cards, *adapter_cards, *prompt_cards])

    async def load_lora_adapter(
            self,
            request: LoadLoRAAdapterRequest,
            base_model_name: Optional[str] = None
    ) -> Union[ErrorResponse, str]:
        """Register a LoRA adapter with the engine and record it locally.

        Args:
            request: Load request carrying ``lora_name`` and ``lora_path``.
            base_model_name: Optional base model to associate with the
                adapter. It is only recorded when it names a served base
                model.

        Returns:
            A success message, or an ``ErrorResponse`` when validation fails
            or the engine rejects the adapter. Engine failures map to 404
            when the error text contains "No adapter found", otherwise 400.
        """
        error_check_ret = await self._check_load_lora_adapter_request(request)
        if error_check_ret is not None:
            return error_check_ret

        lora_name, lora_path = request.lora_name, request.lora_path
        unique_id = self.lora_id_counter.inc(1)
        lora_request = LoRARequest(lora_name=lora_name,
                                   lora_int_id=unique_id,
                                   lora_path=lora_path)
        if base_model_name is not None and self.is_base_model(base_model_name):
            lora_request.base_model_name = base_model_name

        # Validate that the adapter can be loaded into the engine
        # This will also pre-load it for incoming requests
        try:
            await self.engine_client.add_lora(lora_request)
        except Exception as e:
            # Only ordinary errors become an error response. Task
            # cancellation and interpreter-exit signals derive from
            # BaseException and must propagate to the caller.
            error_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            if "No adapter found" in str(e):
                error_type = "NotFoundError"
                status_code = HTTPStatus.NOT_FOUND

            return create_error_response(message=str(e),
                                         err_type=error_type,
                                         status_code=status_code)

        self.lora_requests.append(lora_request)
        logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
                    lora_path)
        return f"Success: LoRA adapter '{lora_name}' added successfully."

    async def unload_lora_adapter(
            self,
            request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
        """Drop a LoRA adapter from the list of loaded adapters.

        Only the local ``lora_requests`` list is updated; this method makes
        no call to the engine.

        Returns:
            A success message, or the ``ErrorResponse`` produced by request
            validation.
        """
        rejection = await self._check_unload_lora_adapter_request(request)
        if rejection is not None:
            return rejection

        target = request.lora_name
        # Rebind to a fresh list rather than mutating the existing one.
        remaining = []
        for loaded in self.lora_requests:
            if loaded.lora_name != target:
                remaining.append(loaded)
        self.lora_requests = remaining

        logger.info("Removed LoRA adapter: name '%s'", target)
        return f"Success: LoRA adapter '{target}' removed successfully."

    async def _check_load_lora_adapter_request(
            self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
        """Validate a load request before the engine is contacted.

        Returns:
            None when the request is acceptable; otherwise a 400
            ``ErrorResponse`` because a field is missing or an adapter with
            the same name is already loaded.
        """
        requested_name = request.lora_name

        # Both the name and the path are mandatory.
        if not (requested_name and request.lora_path):
            return create_error_response(
                message="Both 'lora_name' and 'lora_path' must be provided.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Reject a name that is already taken by a loaded adapter.
        for loaded in self.lora_requests:
            if loaded.lora_name == requested_name:
                return create_error_response(
                    message=
                    f"The lora adapter '{requested_name}' has already been "
                    "loaded.",
                    err_type="InvalidUserInput",
                    status_code=HTTPStatus.BAD_REQUEST)

        return None

    async def _check_unload_lora_adapter_request(
            self,
            request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
        """Validate an unload request against the loaded adapters.

        Loaded adapters are looked up by name, so ``lora_name`` is required.
        A request carrying only ``lora_int_id`` cannot be matched and is
        rejected up front instead of failing later with a misleading
        "'None' cannot be found" error.

        Returns:
            None when the request is acceptable; a 400 ``ErrorResponse``
            when ``lora_name`` is missing or empty; a 404 ``ErrorResponse``
            when no loaded adapter has that name.
        """
        # The adapter can only be identified by its name.
        if not request.lora_name:
            return create_error_response(
                message=
                "'lora_name' needs to be provided to unload a LoRA adapter.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Check if the lora adapter with the given name exists
        if not any(lora_request.lora_name == request.lora_name
                   for lora_request in self.lora_requests):
            return create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' cannot be found.",
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

        return None

248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
    async def resolve_lora(
            self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
        """Attempt to resolve a LoRA adapter using available resolvers.

        Resolvers are tried in order. The first adapter that the engine
        accepts is recorded in ``lora_requests`` and returned. Work for a
        given name is serialized through a per-name lock.

        Args:
            lora_name: Name/identifier of the LoRA adapter

        Returns:
            LoRARequest if found and loaded successfully (or already loaded).
            ErrorResponse (404) if no resolver finds the adapter.
            ErrorResponse (400) if adapter(s) are found but none load.
        """
        async with self.lora_resolver_lock[lora_name]:
            # First check if this LoRA is already loaded
            for existing in self.lora_requests:
                if existing.lora_name == lora_name:
                    return existing

            base_model_name = self.model_config.model
            # One id is drawn per resolution attempt and reused for every
            # candidate the resolvers produce.
            unique_id = self.lora_id_counter.inc(1)
            found_adapter = False

            # Try to resolve using available resolvers
            for resolver in self.lora_resolvers:
                lora_request = await resolver.resolve_lora(
                    base_model_name, lora_name)
                if lora_request is None:
                    continue

                found_adapter = True
                lora_request.lora_int_id = unique_id

                try:
                    await self.engine_client.add_lora(lora_request)
                    self.lora_requests.append(lora_request)
                    logger.info(
                        "Resolved and loaded LoRA adapter '%s' using %s",
                        lora_name, resolver.__class__.__name__)
                    return lora_request
                except Exception as e:
                    # Only ordinary failures fall through to the next
                    # resolver. Task cancellation and interpreter-exit
                    # signals derive from BaseException and must propagate.
                    logger.warning(
                        "Failed to load LoRA '%s' resolved by %s: %s. "
                        "Trying next resolver.", lora_name,
                        resolver.__class__.__name__, e)

            if found_adapter:
                # An adapter was found, but all attempts to load it failed.
                return create_error_response(
                    message=(f"LoRA adapter '{lora_name}' was found "
                             "but could not be loaded."),
                    err_type="BadRequestError",
                    status_code=HTTPStatus.BAD_REQUEST)

            # No adapter was found
            return create_error_response(
                message=f"LoRA adapter {lora_name} does not exist",
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

307
308
309
310
311
312
313
314

def create_error_response(
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
    """Build an ``ErrorResponse`` from a message, error type and HTTP status.

    The response's ``code`` is the integer value of ``status_code``.
    """
    numeric_code = status_code.value
    return ErrorResponse(message=message, type=err_type, code=numeric_code)