# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

from vllm.config import ModelConfig
10
from vllm.engine.protocol import EngineClient
11
from vllm.entrypoints.openai.protocol import (ErrorResponse,
12
13
                                              LoadLoRAAdapterRequest,
                                              UnloadLoRAAdapterRequest)
14
15
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
16
from vllm.lora.request import LoRARequest
17

18
# Base model used by every test; it is registered under its own name as the
# single entry of BASE_MODEL_PATHS.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

# Message templates the tests compare against the values returned by
# load_lora_adapter / unload_lora_adapter; fill in with
# ``.format(lora_name=...)``.
LORA_LOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' removed successfully.")


26
async def _async_serving_models_init() -> OpenAIServingModels:
    """Build an ``OpenAIServingModels`` instance backed by mocks.

    The engine client and the model config are ``MagicMock`` objects specced
    against ``EngineClient`` and ``ModelConfig``. The instance is created
    with ``BASE_MODEL_PATHS`` and ``lora_modules=None``, and
    ``init_static_loras()`` is awaited before it is handed back.

    Returns:
        The initialized ``OpenAIServingModels`` instance.
    """
    mock_model_config = MagicMock(spec=ModelConfig)
    mock_engine_client = MagicMock(spec=EngineClient)
    # Set the max_model_len attribute to avoid missing attribute
    mock_model_config.max_model_len = 2048

    serving_models = OpenAIServingModels(engine_client=mock_engine_client,
                                         base_model_paths=BASE_MODEL_PATHS,
                                         model_config=mock_model_config,
                                         lora_modules=None)
    await serving_models.init_static_loras()

    return serving_models
39
40


41
42
@pytest.mark.asyncio
async def test_serving_model_name():
    """Check what ``model_name`` returns with and without a LoRA request.

    With ``None`` the base ``MODEL_NAME`` is expected; with a
    ``LoRARequest`` the request's ``lora_name`` is expected.
    """
    serving_models = await _async_serving_models_init()
    assert serving_models.model_name(None) == MODEL_NAME

    request = LoRARequest(lora_name="adapter",
                          lora_path="/path/to/adapter2",
                          lora_int_id=1)
    assert serving_models.model_name(request) == request.lora_name
49
50


51
52
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Loading a named adapter returns the success message and registers it.

    After the call, ``lora_requests`` must hold exactly one entry, keyed by
    the adapter name, whose ``lora_name`` matches that name.
    """
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="adapter",
                                     lora_path="/path/to/adapter2")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
    assert len(serving_models.lora_requests) == 1
    assert "adapter" in serving_models.lora_requests
    assert serving_models.lora_requests["adapter"].lora_name == "adapter"
61
62
63
64


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """A load request with empty name and path must be rejected.

    The expected result is an ``ErrorResponse`` of type
    ``"InvalidUserInput"`` carrying ``HTTPStatus.BAD_REQUEST``.
    """
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InvalidUserInput"
    assert response.error.code == HTTPStatus.BAD_REQUEST
71
72
73
74


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter twice must fail on the second attempt.

    The first load is expected to succeed and register one adapter. The
    second, identical request is expected to yield an ``ErrorResponse`` of
    type ``"InvalidUserInput"`` with ``HTTPStatus.BAD_REQUEST`` and to leave
    the number of registered adapters at one.
    """
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    # Same name and path again: this must be refused.
    request = LoadLoRAAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InvalidUserInput"
    assert response.error.code == HTTPStatus.BAD_REQUEST
    assert len(serving_models.lora_requests) == 1
90
91
92
93


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Unloading a previously loaded adapter removes it again.

    The adapter is loaded first (and that load is verified), then unloaded
    by name. The unload is expected to return the success message and leave
    ``lora_requests`` empty.
    """
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    # Verify the setup step itself so a failed load is reported as such
    # instead of surfacing later as a confusing unload failure.
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    request = UnloadLoRAAdapterRequest(lora_name="adapter1")
    response = await serving_models.unload_lora_adapter(request)
    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 0
105
106
107
108


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """An unload request with an empty name and no id must be rejected.

    The expected result is an ``ErrorResponse`` of type
    ``"InvalidUserInput"`` carrying ``HTTPStatus.BAD_REQUEST``.
    """
    serving_models = await _async_serving_models_init()
    request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InvalidUserInput"
    assert response.error.code == HTTPStatus.BAD_REQUEST
115
116
117
118


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading an adapter that was never loaded must report not-found.

    The expected result is an ``ErrorResponse`` of type ``"NotFoundError"``
    carrying ``HTTPStatus.NOT_FOUND``.
    """
    serving_models = await _async_serving_models_init()
    request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.error.type == "NotFoundError"
    assert response.error.code == HTTPStatus.NOT_FOUND