serving_models.py 11.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from asyncio import Lock
from collections import defaultdict
6
7
from dataclasses import dataclass
from http import HTTPStatus
8
from typing import Optional, Union
9
10

from vllm.config import ModelConfig
11
from vllm.engine.protocol import EngineClient
12
from vllm.entrypoints.openai.protocol import (ErrorInfo, ErrorResponse,
13
                                              LoadLoRAAdapterRequest,
14
15
                                              ModelCard, ModelList,
                                              ModelPermission,
16
                                              UnloadLoRAAdapterRequest)
17
from vllm.logger import init_logger
18
from vllm.lora.request import LoRARequest
19
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
20
21
from vllm.utils import AtomicCounter

22
23
logger = init_logger(__name__)

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

@dataclass
class BaseModelPath:
    """Pairs the served name of a base model with where it was loaded from."""

    # Identifier exposed to clients; used as the model card ``id`` and
    # compared against request model names.
    name: str
    # Location of the model; reported as the model card ``root``.
    model_path: str


@dataclass
class LoRAModulePath:
    """Describes a LoRA adapter that should be loaded at startup."""

    # Name the adapter is registered and served under.
    name: str
    # Location of the adapter weights, forwarded as the load request's
    # ``lora_path``.
    path: str
    # Optional name of the base model this adapter belongs to; it is only
    # recorded on the adapter when it matches a served base model.
    base_model_name: Optional[str] = None


class OpenAIServingModels:
    """Shared instance to hold data about the loaded base model(s) and adapters.

    Handles the routes:
    - /v1/models
    - /v1/load_lora_adapter
    - /v1/unload_lora_adapter
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: list[BaseModelPath],
        *,
        lora_modules: Optional[list[LoRAModulePath]] = None,
    ):
        """Store the serving configuration and set up LoRA bookkeeping.

        Args:
            engine_client: Engine handle used to register LoRA adapters.
            model_config: Model configuration; supplies ``max_model_len``
                and the base model identifier passed to resolvers.
            base_model_paths: Base models being served. The first entry is
                treated as the default model.
            lora_modules: Adapters to load eagerly via
                ``init_static_loras``; ``None`` means there are none.
        """
        super().__init__()

        self.base_model_paths = base_model_paths
        self.max_model_len = model_config.max_model_len
        self.engine_client = engine_client
        self.model_config = model_config

        self.static_lora_modules = lora_modules
        # Registry of currently loaded adapters, keyed by adapter name.
        self.lora_requests: dict[str, LoRARequest] = {}
        # Source of unique integer ids for newly created LoRA requests.
        self.lora_id_counter = AtomicCounter(0)

        # Instantiate every resolver the registry reports as supported.
        self.lora_resolvers: list[LoRAResolver] = [
            LoRAResolverRegistry.get_resolver(resolver_name)
            for resolver_name in
            LoRAResolverRegistry.get_supported_resolvers()
        ]
        # One lock per adapter name so that load / unload / resolve calls
        # for the same adapter never interleave.
        self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)

    async def init_static_loras(self):
        """Load every statically configured LoRA module.

        Raises:
            ValueError: If any adapter fails to load; the message is the one
                carried by the returned error response.
        """
        if self.static_lora_modules is None:
            return
        for lora in self.static_lora_modules:
            load_request = LoadLoRAAdapterRequest(lora_path=lora.path,
                                                  lora_name=lora.name)
            load_result = await self.load_lora_adapter(
                request=load_request, base_model_name=lora.base_model_name)
            if isinstance(load_result, ErrorResponse):
                raise ValueError(load_result.error.message)

    def is_base_model(self, model_name: str) -> bool:
        """Return True if ``model_name`` matches a served base model name."""
        return any(model.name == model_name for model in self.base_model_paths)

    def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
        """Return the model name to report for a request.

        Args:
            lora_request: The LoRA request in use, if any.

        Returns:
            The adapter's ``lora_name`` when ``lora_request`` is given,
            otherwise the name of the first configured base model.
        """
        if lora_request is not None:
            return lora_request.lora_name
        return self.base_model_paths[0].name

    async def show_available_models(self) -> ModelList:
        """Show available models. This includes the base model and all
        adapters"""
        model_cards = [
            ModelCard(id=base_model.name,
                      max_model_len=self.max_model_len,
                      root=base_model.model_path,
                      permission=[ModelPermission()])
            for base_model in self.base_model_paths
        ]
        # Adapters without an explicit base model are reported as children
        # of the first (default) base model.
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=lora.local_path,
                      parent=lora.base_model_name if lora.base_model_name else
                      self.base_model_paths[0].name,
                      permission=[ModelPermission()])
            for lora in self.lora_requests.values()
        ]
        model_cards.extend(lora_cards)
        return ModelList(data=model_cards)

    async def load_lora_adapter(
            self,
            request: LoadLoRAAdapterRequest,
            base_model_name: Optional[str] = None
    ) -> Union[ErrorResponse, str]:
        """Validate a load request and register the adapter with the engine.

        Args:
            request: Carries the adapter's ``lora_name`` and ``lora_path``.
            base_model_name: Optional base model to associate with the
                adapter; ignored unless it names a served base model.

        Returns:
            A success message, or an ``ErrorResponse`` when validation fails
            or the engine rejects the adapter (404 if the engine error
            mentions "No adapter found", 400 otherwise).
        """
        lora_name = request.lora_name

        # Ensure atomicity based on the lora name
        async with self.lora_resolver_lock[lora_name]:
            error_check_ret = await self._check_load_lora_adapter_request(
                request)
            if error_check_ret is not None:
                return error_check_ret

            lora_path = request.lora_path
            unique_id = self.lora_id_counter.inc(1)
            lora_request = LoRARequest(lora_name=lora_name,
                                       lora_int_id=unique_id,
                                       lora_path=lora_path)
            if base_model_name is not None and self.is_base_model(
                    base_model_name):
                lora_request.base_model_name = base_model_name

            # Validate that the adapter can be loaded into the engine
            # This will also pre-load it for incoming requests
            try:
                await self.engine_client.add_lora(lora_request)
            except Exception as e:
                error_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                if "No adapter found" in str(e):
                    error_type = "NotFoundError"
                    status_code = HTTPStatus.NOT_FOUND

                return create_error_response(message=str(e),
                                             err_type=error_type,
                                             status_code=status_code)

            # Only record the adapter once the engine has accepted it.
            self.lora_requests[lora_name] = lora_request
            logger.info("Loaded new LoRA adapter: name '%s', path '%s'",
                        lora_name, lora_path)
            return f"Success: LoRA adapter '{lora_name}' added successfully."

    async def unload_lora_adapter(
            self,
            request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
        """Remove an adapter from the registry of loaded adapters.

        Only the entry in ``self.lora_requests`` is dropped; this method
        makes no call to the engine.

        Returns:
            A success message, or an ``ErrorResponse`` when the name is
            missing (400) or not currently loaded (404).
        """
        lora_name = request.lora_name

        # Ensure atomicity based on the lora name
        async with self.lora_resolver_lock[lora_name]:
            error_check_ret = await self._check_unload_lora_adapter_request(
                request)
            if error_check_ret is not None:
                return error_check_ret

            # Safe to delete now since we hold the lock
            del self.lora_requests[lora_name]
            logger.info("Removed LoRA adapter: name '%s'", lora_name)
            return f"Success: LoRA adapter '{lora_name}' removed successfully."

    async def _check_load_lora_adapter_request(
            self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
        """Return an error response if a load request is invalid, else None."""
        # Check if both 'lora_name' and 'lora_path' are provided
        if not request.lora_name or not request.lora_path:
            return create_error_response(
                message="Both 'lora_name' and 'lora_path' must be provided.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Check if the lora adapter with the given name already exists
        if request.lora_name in self.lora_requests:
            return create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' has already been "
                "loaded.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        return None

    async def _check_unload_lora_adapter_request(
            self,
            request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
        """Return an error response if an unload request is invalid, else
        None."""
        # Check if 'lora_name' is not provided return an error
        if not request.lora_name:
            return create_error_response(
                message=
                "'lora_name' needs to be provided to unload a LoRA adapter.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Check if the lora adapter with the given name exists
        if request.lora_name not in self.lora_requests:
            return create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' cannot be found.",
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

        return None

    async def resolve_lora(
            self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
        """Attempt to resolve a LoRA adapter using available resolvers.

        Args:
            lora_name: Name/identifier of the LoRA adapter

        Returns:
            LoRARequest if found and loaded successfully.
            ErrorResponse (404) if no resolver finds the adapter.
            ErrorResponse (400) if adapter(s) are found but none load.
        """
        async with self.lora_resolver_lock[lora_name]:
            # First check if this LoRA is already loaded
            if lora_name in self.lora_requests:
                return self.lora_requests[lora_name]

            base_model_name = self.model_config.model
            # A single id is reserved up front and reused for whichever
            # resolver's result ends up being loaded.
            unique_id = self.lora_id_counter.inc(1)
            found_adapter = False

            # Try to resolve using available resolvers
            for resolver in self.lora_resolvers:
                lora_request = await resolver.resolve_lora(
                    base_model_name, lora_name)

                if lora_request is not None:
                    found_adapter = True
                    lora_request.lora_int_id = unique_id

                    try:
                        await self.engine_client.add_lora(lora_request)
                        self.lora_requests[lora_name] = lora_request
                        logger.info(
                            "Resolved and loaded LoRA adapter '%s' using %s",
                            lora_name, resolver.__class__.__name__)
                        return lora_request
                    # Catch Exception only: task cancellation and
                    # interpreter-exit signals derive from BaseException
                    # and must propagate instead of being treated as a
                    # failed load.
                    except Exception as e:
                        logger.warning(
                            "Failed to load LoRA '%s' resolved by %s: %s. "
                            "Trying next resolver.", lora_name,
                            resolver.__class__.__name__, e)
                        continue

            if found_adapter:
                # An adapter was found, but all attempts to load it failed.
                return create_error_response(
                    message=(f"LoRA adapter '{lora_name}' was found "
                             "but could not be loaded."),
                    err_type="BadRequestError",
                    status_code=HTTPStatus.BAD_REQUEST)
            else:
                # No adapter was found
                return create_error_response(
                    message=f"LoRA adapter {lora_name} does not exist",
                    err_type="NotFoundError",
                    status_code=HTTPStatus.NOT_FOUND)

282
283
284
285
286

def create_error_response(
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
    """Build an ``ErrorResponse`` describing a failed request.

    Args:
        message: Human-readable description of the failure.
        err_type: Error category string stored in the ``type`` field.
        status_code: HTTP status whose integer value is stored as ``code``.

    Returns:
        An ``ErrorResponse`` wrapping an ``ErrorInfo`` with these fields.
    """
    error_info = ErrorInfo(message=message,
                           type=err_type,
                           code=status_code.value)
    return ErrorResponse(error=error_info)