"docs/community/sponsors.md" did not exist on "856c990041bf6cf4b2397401d4b18531382ecb50"
serving_models.py 11.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from asyncio import Lock
from collections import defaultdict
6
7
8
from dataclasses import dataclass
from http import HTTPStatus

9
from vllm.engine.protocol import EngineClient
10
from vllm.entrypoints.openai.engine.protocol import (
11
12
13
14
15
    ErrorInfo,
    ErrorResponse,
    ModelCard,
    ModelList,
    ModelPermission,
16
17
18
)
from vllm.entrypoints.serve.lora.protocol import (
    LoadLoRAAdapterRequest,
19
20
    UnloadLoRAAdapterRequest,
)
21
from vllm.entrypoints.utils import sanitize_message
22
from vllm.logger import init_logger
23
from vllm.lora.request import LoRARequest
24
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
25
from vllm.utils.counter import AtomicCounter
26

27
28
# Module-level logger obtained from vLLM's init_logger, named after this module.
logger = init_logger(__name__)

29
30
31
32
33
34
35
36
37
38
39

@dataclass
class BaseModelPath:
    """Pairs the name a base model is served under with its model path."""

    # Name compared against requested model names and reported as the card id.
    name: str
    # Reported as the ``root`` of the base model's card in the model listing.
    model_path: str


@dataclass
class LoRAModulePath:
    name: str
    path: str
40
    base_model_name: str | None = None
41
42
43
44
45
46
47
48
49
50
51
52
53


class OpenAIServingModels:
    """Shared instance to hold data about the loaded base model(s) and adapters.

    Handles the routes:
    - /v1/models
    - /v1/load_lora_adapter
    - /v1/unload_lora_adapter
    """

    def __init__(
        self,
        engine_client: EngineClient,
        base_model_paths: list[BaseModelPath],
        *,
        lora_modules: list[LoRAModulePath] | None = None,
    ):
        """Store the engine client and model paths and set up LoRA bookkeeping.

        Args:
            engine_client: Engine used to add LoRA adapters; its
                ``input_processor``, ``io_processor`` and ``model_config``
                are also cached on this instance.
            base_model_paths: Base models served by this instance.
            lora_modules: Adapters to load later via ``init_static_loras``;
                ``None`` means there are no static adapters.
        """
        super().__init__()

        self.engine_client = engine_client
        self.base_model_paths = base_model_paths

        self.static_lora_modules = lora_modules
        # Registry of currently loaded adapters, keyed by adapter name.
        self.lora_requests: dict[str, LoRARequest] = {}
        # Source of the integer ids assigned to new LoRA requests.
        self.lora_id_counter = AtomicCounter(0)

        # One resolver instance per resolver name known to the registry.
        self.lora_resolvers: list[LoRAResolver] = [
            LoRAResolverRegistry.get_resolver(name)
            for name in LoRAResolverRegistry.get_supported_resolvers()
        ]
        # One lock per adapter name, created lazily on first access.
        # NOTE(review): nothing in this class removes entries from this dict.
        self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)

        self.input_processor = self.engine_client.input_processor
        self.io_processor = self.engine_client.io_processor
        self.model_config = self.engine_client.model_config
        self.max_model_len = self.model_config.max_model_len

    async def init_static_loras(self):
        """Load every static LoRA module passed to the constructor.

        Raises:
            ValueError: If loading any module returns an ``ErrorResponse``;
                the message is taken from that response.
        """
        if self.static_lora_modules is None:
            return
        for lora in self.static_lora_modules:
            load_request = LoadLoRAAdapterRequest(
                lora_path=lora.path, lora_name=lora.name
            )
            load_result = await self.load_lora_adapter(
                request=load_request, base_model_name=lora.base_model_name
            )
            if isinstance(load_result, ErrorResponse):
                raise ValueError(load_result.error.message)

    def is_base_model(self, model_name) -> bool:
        """Return True if ``model_name`` equals the name of any base model."""
        return any(model.name == model_name for model in self.base_model_paths)

    def model_name(self, lora_request: LoRARequest | None = None) -> str:
        """Return the model name to report for a request.

        Args:
            lora_request: Optional LoRA request associated with the call.

        Returns:
            ``lora_request.lora_name`` when a LoRA request is given,
            otherwise the name of the first entry in ``base_model_paths``.
        """
        if lora_request is not None:
            return lora_request.lora_name
        return self.base_model_paths[0].name

    async def show_available_models(self) -> ModelList:
        """Show available models. This includes the base model(s) and all
        currently registered adapters.

        Returns:
            A ``ModelList`` with one card per base model followed by one card
            per loaded LoRA adapter.
        """
        model_cards = [
            ModelCard(
                id=base_model.name,
                max_model_len=self.max_model_len,
                root=base_model.model_path,
                permission=[ModelPermission()],
            )
            for base_model in self.base_model_paths
        ]
        lora_cards = [
            ModelCard(
                id=lora.lora_name,
                root=lora.path,
                # Fall back to the first base model when the adapter does not
                # carry a base model name of its own.
                parent=lora.base_model_name
                if lora.base_model_name
                else self.base_model_paths[0].name,
                permission=[ModelPermission()],
            )
            for lora in self.lora_requests.values()
        ]
        model_cards.extend(lora_cards)
        return ModelList(data=model_cards)

    async def load_lora_adapter(
        self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None
    ) -> ErrorResponse | str:
        """Validate a load request, add the adapter to the engine and register it.

        Args:
            request: Carries the adapter's ``lora_name`` and ``lora_path``.
            base_model_name: Optional base model name; it is recorded on the
                LoRA request only when it matches a known base model.

        Returns:
            A success message string, or an ``ErrorResponse`` if validation
            fails or the engine raises while adding the adapter.
        """
        lora_name = request.lora_name

        # Ensure atomicity based on the lora name
        async with self.lora_resolver_lock[lora_name]:
            error_check_ret = await self._check_load_lora_adapter_request(request)
            if error_check_ret is not None:
                return error_check_ret

            lora_path = request.lora_path
            unique_id = self.lora_id_counter.inc(1)
            lora_request = LoRARequest(
                lora_name=lora_name, lora_int_id=unique_id, lora_path=lora_path
            )
            if base_model_name is not None and self.is_base_model(base_model_name):
                lora_request.base_model_name = base_model_name

            # Validate that the adapter can be loaded into the engine
            # This will also preload it for incoming requests
            try:
                await self.engine_client.add_lora(lora_request)
            except Exception as e:
                error_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                # The engine's error text is the only signal used here to
                # distinguish a missing adapter from other failures.
                if "No adapter found" in str(e):
                    error_type = "NotFoundError"
                    status_code = HTTPStatus.NOT_FOUND

                return create_error_response(
                    message=str(e), err_type=error_type, status_code=status_code
                )

            self.lora_requests[lora_name] = lora_request
            logger.info(
                "Loaded new LoRA adapter: name '%s', path '%s'", lora_name, lora_path
            )
            return f"Success: LoRA adapter '{lora_name}' added successfully."

    async def unload_lora_adapter(
        self, request: UnloadLoRAAdapterRequest
    ) -> ErrorResponse | str:
        """Validate an unload request and drop the adapter from the registry.

        This method only removes the entry from ``self.lora_requests``; it
        makes no call to the engine client.

        Args:
            request: Carries the ``lora_name`` of the adapter to remove.

        Returns:
            A success message string, or an ``ErrorResponse`` if validation
            fails.
        """
        lora_name = request.lora_name

        # Ensure atomicity based on the lora name
        async with self.lora_resolver_lock[lora_name]:
            error_check_ret = await self._check_unload_lora_adapter_request(request)
            if error_check_ret is not None:
                return error_check_ret

            # Safe to delete now since we hold the lock
            del self.lora_requests[lora_name]
            logger.info("Removed LoRA adapter: name '%s'", lora_name)
            return f"Success: LoRA adapter '{lora_name}' removed successfully."

    async def _check_load_lora_adapter_request(
        self, request: LoadLoRAAdapterRequest
    ) -> ErrorResponse | None:
        """Return an ``ErrorResponse`` if the load request is invalid, else None.

        A request is rejected when ``lora_name`` or ``lora_path`` is missing or
        empty, or when an adapter with the same name is already registered.
        """
        # Check if both 'lora_name' and 'lora_path' are provided
        if not request.lora_name or not request.lora_path:
            return create_error_response(
                message="Both 'lora_name' and 'lora_path' must be provided.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST,
            )

        # Check if the lora adapter with the given name already exists
        if request.lora_name in self.lora_requests:
            return create_error_response(
                message=f"The lora adapter '{request.lora_name}' has already been "
                "loaded.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST,
            )

        return None

    async def _check_unload_lora_adapter_request(
        self, request: UnloadLoRAAdapterRequest
    ) -> ErrorResponse | None:
        """Return an ``ErrorResponse`` if the unload request is invalid, else None.

        A request is rejected when ``lora_name`` is missing or empty (400), or
        when no adapter with that name is registered (404).
        """
        # Check if 'lora_name' is not provided return an error
        if not request.lora_name:
            return create_error_response(
                message="'lora_name' needs to be provided to unload a LoRA adapter.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST,
            )

        # Check if the lora adapter with the given name exists
        if request.lora_name not in self.lora_requests:
            return create_error_response(
                message=f"The lora adapter '{request.lora_name}' cannot be found.",
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND,
            )

        return None

    async def resolve_lora(self, lora_name: str) -> LoRARequest | ErrorResponse:
        """Attempt to resolve a LoRA adapter using available resolvers.

        Args:
            lora_name: Name/identifier of the LoRA adapter

        Returns:
            LoRARequest if found and loaded successfully.
            ErrorResponse (404) if no resolver finds the adapter.
            ErrorResponse (400) if adapter(s) are found but none load.
        """
        async with self.lora_resolver_lock[lora_name]:
            # First check if this LoRA is already loaded
            if lora_name in self.lora_requests:
                return self.lora_requests[lora_name]

            base_model_name = self.model_config.model
            # A single id is drawn up front and reused for whichever
            # resolver's result ends up being tried.
            unique_id = self.lora_id_counter.inc(1)
            found_adapter = False

            # Try to resolve using available resolvers
            for resolver in self.lora_resolvers:
                lora_request = await resolver.resolve_lora(base_model_name, lora_name)

                if lora_request is not None:
                    found_adapter = True
                    lora_request.lora_int_id = unique_id

                    try:
                        await self.engine_client.add_lora(lora_request)
                        self.lora_requests[lora_name] = lora_request
                        logger.info(
                            "Resolved and loaded LoRA adapter '%s' using %s",
                            lora_name,
                            resolver.__class__.__name__,
                        )
                        return lora_request
                    # Catch Exception rather than BaseException so that task
                    # cancellation (asyncio.CancelledError), KeyboardInterrupt
                    # and SystemExit propagate instead of being logged and
                    # swallowed while the loop moves on to the next resolver.
                    except Exception as e:
                        logger.warning(
                            "Failed to load LoRA '%s' resolved by %s: %s. "
                            "Trying next resolver.",
                            lora_name,
                            resolver.__class__.__name__,
                            e,
                        )
                        continue

            if found_adapter:
                # An adapter was found, but all attempts to load it failed.
                return create_error_response(
                    message=(
                        f"LoRA adapter '{lora_name}' was found but could not be loaded."
                    ),
                    err_type="BadRequestError",
                    status_code=HTTPStatus.BAD_REQUEST,
                )
            else:
                # No adapter was found
                return create_error_response(
                    message=f"LoRA adapter {lora_name} does not exist",
                    err_type="NotFoundError",
                    status_code=HTTPStatus.NOT_FOUND,
                )
298

299
300

def create_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse:
    """Build an ``ErrorResponse`` wrapping a single ``ErrorInfo``.

    Args:
        message: Error text; it is passed through ``sanitize_message`` before
            being placed in the response.
        err_type: Value stored in the error's ``type`` field.
        status_code: HTTP status whose integer ``value`` is stored in the
            error's ``code`` field.

    Returns:
        The populated ``ErrorResponse``.
    """
    return ErrorResponse(
        error=ErrorInfo(
            message=sanitize_message(message),
            type=err_type,
            code=status_code.value,
        )
    )