test_serving_models.py 5.07 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

from vllm.config import ModelConfig
9
from vllm.engine.protocol import EngineClient
10
from vllm.entrypoints.openai.protocol import (ErrorResponse,
11
12
                                              LoadLoRAAdapterRequest,
                                              UnloadLoRAAdapterRequest)
13
14
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
15
from vllm.lora.request import LoRARequest
16

17
# Base model served in these tests; the same string is used as both the
# public model name and the model path.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

# Message templates the tests compare against the responses of
# load_lora_adapter / unload_lora_adapter; fill in with
# ``.format(lora_name=...)``.
LORA_LOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' removed successfully.")


25
async def _async_serving_models_init() -> OpenAIServingModels:
    """Build an ``OpenAIServingModels`` instance backed by mocks.

    The model config and engine client are ``MagicMock`` objects created
    with ``spec=`` for their respective classes. No LoRA modules or prompt
    adapters are passed in, and ``init_static_loras()`` is awaited before
    the instance is handed back.

    Returns:
        The constructed ``OpenAIServingModels`` object.
    """
    mock_model_config = MagicMock(spec=ModelConfig)
    mock_engine_client = MagicMock(spec=EngineClient)
    # Set the max_model_len attribute to avoid missing attribute
    mock_model_config.max_model_len = 2048

    serving_models = OpenAIServingModels(engine_client=mock_engine_client,
                                         base_model_paths=BASE_MODEL_PATHS,
                                         model_config=mock_model_config,
                                         lora_modules=None,
                                         prompt_adapters=None)
    await serving_models.init_static_loras()

    return serving_models
39
40


41
42
@pytest.mark.asyncio
async def test_serving_model_name():
    """Check ``model_name`` with and without a LoRA request.

    With ``None`` the result must equal ``MODEL_NAME``; with a
    ``LoRARequest`` it must equal that request's ``lora_name``.
    """
    serving_models = await _async_serving_models_init()
    assert serving_models.model_name(None) == MODEL_NAME

    request = LoRARequest(lora_name="adapter",
                          lora_path="/path/to/adapter2",
                          lora_int_id=1)
    assert serving_models.model_name(request) == request.lora_name
49
50


51
52
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Loading one adapter returns the success message and registers it.

    After the call, ``lora_requests`` must hold exactly one entry whose
    ``lora_name`` is the name that was requested.
    """
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="adapter",
                                     lora_path="/path/to/adapter2")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
    assert len(serving_models.lora_requests) == 1
    assert serving_models.lora_requests[0].lora_name == "adapter"
60
61
62
63


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """A load request with empty name and path yields a 400 error response.

    The response must be an ``ErrorResponse`` of type ``"InvalidUserInput"``
    with code ``HTTPStatus.BAD_REQUEST``.
    """
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter twice is rejected the second time.

    The first load must succeed and register one adapter. The second,
    identical load must return an ``ErrorResponse`` of type
    ``"InvalidUserInput"`` with code ``HTTPStatus.BAD_REQUEST`` and must
    leave the number of registered adapters at one.
    """
    serving_models = await _async_serving_models_init()

    # First load: expected to succeed.
    request = LoadLoRAAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    # Second load with the same name and path: expected to be refused.
    request = LoadLoRAAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST
    assert len(serving_models.lora_requests) == 1
89
90
91
92


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Unloading a previously loaded adapter removes it.

    An adapter is loaded first (and that load is verified), then unloaded
    by name. The unload must return the removal success message and leave
    ``lora_requests`` empty.
    """
    serving_models = await _async_serving_models_init()

    # Arrange: load the adapter and confirm the precondition actually holds,
    # so a failed load is reported here rather than as a confusing unload
    # failure further down.
    request = LoadLoRAAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    # Act / assert: unload it by name.
    request = UnloadLoRAAdapterRequest(lora_name="adapter1")
    response = await serving_models.unload_lora_adapter(request)
    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 0
104
105
106
107


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """An unload request with an empty name and no id yields a 400 error.

    The response must be an ``ErrorResponse`` of type ``"InvalidUserInput"``
    with code ``HTTPStatus.BAD_REQUEST``.
    """
    serving_models = await _async_serving_models_init()
    request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading an adapter that was never loaded yields a 404 error.

    The response must be an ``ErrorResponse`` of type ``"NotFoundError"``
    with code ``HTTPStatus.NOT_FOUND``.
    """
    serving_models = await _async_serving_models_init()
    request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "NotFoundError"
    assert response.code == HTTPStatus.NOT_FOUND