dynamic_lora.py 2.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import model_hosting_container_standards.sagemaker as sagemaker_standards
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, Response

from vllm.entrypoints.openai.api_server import models, validate_json_request
from vllm.entrypoints.openai.protocol import (
    ErrorResponse,
    LoadLoRAAdapterRequest,
    UnloadLoRAAdapterRequest,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger

# Module-level logger created with this module's ``__name__``.
# NOTE(review): nothing in the visible part of this file uses it — confirm
# before removing.
logger = init_logger(__name__)


def register_dynamic_lora_routes(router: APIRouter):
    """Attach the dynamic LoRA load/unload endpoints to ``router``.

    Two POST routes are registered: ``/v1/load_lora_adapter`` and
    ``/v1/unload_lora_adapter``. Both run ``validate_json_request`` as a
    route dependency, and each handler is additionally wrapped with the
    corresponding SageMaker adapter-handler decorator. The ``request_shape``
    passed to those decorators presumably tells the SageMaker layer where to
    find each vLLM request field in its own request (TODO confirm against
    the ``model_hosting_container_standards`` documentation).

    Args:
        router: Router that receives the two routes.

    Returns:
        The same ``router`` object that was passed in.
    """

    def _to_http_response(outcome):
        # Common tail of both endpoints: an ErrorResponse is serialised as
        # JSON under the status code it carries; anything else is sent back
        # as the body of a plain 200 response.
        if not isinstance(outcome, ErrorResponse):
            return Response(status_code=200, content=outcome)

        payload = outcome.model_dump()
        status = outcome.error.code
        return JSONResponse(content=payload, status_code=status)

    # NOTE: the handlers below are documented with comments rather than
    # docstrings on purpose — FastAPI exposes a handler's docstring as the
    # route description in the generated OpenAPI schema.

    @sagemaker_standards.register_load_adapter_handler(
        request_shape={"lora_name": "body.name", "lora_path": "body.src"},
    )
    @router.post(
        "/v1/load_lora_adapter",
        dependencies=[Depends(validate_json_request)],
    )
    async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
        # Look up the serving-models object for this request and ask it to
        # load the adapter described by ``request``.
        serving_models: OpenAIServingModels = models(raw_request)
        outcome = await serving_models.load_lora_adapter(request)
        return _to_http_response(outcome)

    @sagemaker_standards.register_unload_adapter_handler(
        request_shape={"lora_name": "path_params.adapter_name"},
    )
    @router.post(
        "/v1/unload_lora_adapter",
        dependencies=[Depends(validate_json_request)],
    )
    async def unload_lora_adapter(
        request: UnloadLoRAAdapterRequest, raw_request: Request
    ):
        # Mirror of the load endpoint, delegating to ``unload_lora_adapter``.
        serving_models: OpenAIServingModels = models(raw_request)
        outcome = await serving_models.unload_lora_adapter(request)
        return _to_http_response(outcome)

    return router