# test_serving_models.py
# Tests for OpenAIServingModels: model-name lookup and LoRA adapter load/unload.
from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              UnloadLoraAdapterRequest)
10
11
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
12
from vllm.lora.request import LoRARequest
13
14

# Name of the base model; reused as its model path in BASE_MODEL_PATHS.
MODEL_NAME = "meta-llama/Llama-2-7b"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
# Templates for the plain-string responses the LoRA load/unload tests compare
# against; ``lora_name`` is filled in with ``str.format``.
LORA_LOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' removed successfully.")


22
async def _async_serving_models_init() -> OpenAIServingModels:
    """Build an ``OpenAIServingModels`` backed by a mocked ``ModelConfig``.

    The instance is created from ``BASE_MODEL_PATHS`` with no pre-configured
    LoRA modules and no prompt adapters.

    Returns:
        A new ``OpenAIServingModels`` instance.
    """
    mock_model_config = MagicMock(spec=ModelConfig)
    # Set the max_model_len attribute to avoid missing attribute
    mock_model_config.max_model_len = 2048

    return OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
                               model_config=mock_model_config,
                               lora_modules=None,
                               prompt_adapters=None)
33
34


35
36
@pytest.mark.asyncio
async def test_serving_model_name():
    """``model_name`` yields the base model name, or the LoRA name if given."""
    serving_models = await _async_serving_models_init()
    # Without a LoRA request the base model name is expected.
    assert serving_models.model_name(None) == MODEL_NAME

    request = LoRARequest(lora_name="adapter",
                          lora_path="/path/to/adapter2",
                          lora_int_id=1)
    # With a LoRA request the adapter's own name is expected.
    assert serving_models.model_name(request) == request.lora_name
43
44


45
46
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Loading a named adapter returns the success message and registers it."""
    serving_models = await _async_serving_models_init()
    request = LoadLoraAdapterRequest(lora_name="adapter",
                                     lora_path="/path/to/adapter2")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
    # Exactly one adapter should now be tracked, under the requested name.
    assert len(serving_models.lora_requests) == 1
    assert serving_models.lora_requests[0].lora_name == "adapter"
54
55
56
57


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """A load request with empty name and path yields a 400 error response."""
    serving_models = await _async_serving_models_init()
    request = LoadLoraAdapterRequest(lora_name="", lora_path="")
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter twice fails the second time with a 400."""
    serving_models = await _async_serving_models_init()
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    # Same name and path again: expect an error and no second registration.
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST
    assert len(serving_models.lora_requests) == 1
83
84
85
86


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Unloading a previously loaded adapter succeeds and removes it."""
    serving_models = await _async_serving_models_init()
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    # Verify the setup step itself, so a failed load is reported as such
    # instead of surfacing later as a confusing unload failure.
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    request = UnloadLoraAdapterRequest(lora_name="adapter1")
    response = await serving_models.unload_lora_adapter(request)
    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 0
98
99
100
101


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """An unload request with empty name and no id yields a 400 error."""
    serving_models = await _async_serving_models_init()
    request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading an adapter that was never loaded yields a 400 error."""
    serving_models = await _async_serving_models_init()
    request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST