# test_serving_models.py
from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

from vllm.config import ModelConfig
7
from vllm.engine.protocol import EngineClient
8
9
10
from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              UnloadLoraAdapterRequest)
11
12
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
13
from vllm.lora.request import LoRARequest
14
15

# Name used both as the served model name and as its model path in these tests.
MODEL_NAME = "meta-llama/Llama-2-7b"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

# Message templates the tests compare against the responses returned by
# load_lora_adapter / unload_lora_adapter; format with ``lora_name``.
LORA_LOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' removed successfully.")


23
async def _async_serving_models_init() -> OpenAIServingModels:
    """Build an ``OpenAIServingModels`` instance backed by mocks.

    The engine client and model config are ``MagicMock`` objects constrained
    to the ``EngineClient`` and ``ModelConfig`` specs. No LoRA modules or
    prompt adapters are configured, and ``init_static_loras`` is awaited
    before the instance is returned.

    Returns:
        The initialized ``OpenAIServingModels`` instance.
    """
    mock_model_config = MagicMock(spec=ModelConfig)
    mock_engine_client = MagicMock(spec=EngineClient)

    # Set max_model_len explicitly so reading it does not fail with a
    # missing-attribute error on the spec'd mock.
    mock_model_config.max_model_len = 2048

    serving_models = OpenAIServingModels(engine_client=mock_engine_client,
                                         base_model_paths=BASE_MODEL_PATHS,
                                         model_config=mock_model_config,
                                         lora_modules=None,
                                         prompt_adapters=None)
    await serving_models.init_static_loras()

    return serving_models
37
38


39
40
@pytest.mark.asyncio
async def test_serving_model_name():
    """``model_name`` yields the base name for ``None`` and the adapter name
    for a ``LoRARequest``."""
    serving_models = await _async_serving_models_init()

    # Without a LoRA request the base model name is expected.
    assert serving_models.model_name(None) == MODEL_NAME

    # With a LoRA request the adapter's own name is expected.
    request = LoRARequest(lora_name="adapter",
                          lora_path="/path/to/adapter2",
                          lora_int_id=1)
    assert serving_models.model_name(request) == request.lora_name
47
48


49
50
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Loading a new adapter returns the success message and records it in
    ``lora_requests``."""
    serving_models = await _async_serving_models_init()

    request = LoadLoraAdapterRequest(lora_name="adapter",
                                     lora_path="/path/to/adapter2")
    response = await serving_models.load_lora_adapter(request)

    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
    assert len(serving_models.lora_requests) == 1
    assert serving_models.lora_requests[0].lora_name == "adapter"
58
59
60
61


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """A load request with empty name and path is rejected as
    ``InvalidUserInput`` with HTTP 400."""
    serving_models = await _async_serving_models_init()

    request = LoadLoraAdapterRequest(lora_name="", lora_path="")
    response = await serving_models.load_lora_adapter(request)

    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter twice succeeds once, then is rejected as
    ``InvalidUserInput`` with HTTP 400 without adding a second entry."""
    serving_models = await _async_serving_models_init()

    # First load: expected to succeed and register exactly one adapter.
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    # Second load with identical name/path: expected to be refused.
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST

    # The rejected request must not have changed the registered adapters.
    assert len(serving_models.lora_requests) == 1
87
88
89
90


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Unloading a previously loaded adapter returns the success message and
    removes it from ``lora_requests``."""
    serving_models = await _async_serving_models_init()

    # Arrange: load the adapter and verify the load itself succeeded, so a
    # failure here is reported as a load problem rather than an unload one.
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    # Act + assert: unload by name and check the adapter is gone.
    request = UnloadLoraAdapterRequest(lora_name="adapter1")
    response = await serving_models.unload_lora_adapter(request)
    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 0
102
103
104
105


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """An unload request with an empty name and no id is rejected as
    ``InvalidUserInput`` with HTTP 400."""
    serving_models = await _async_serving_models_init()

    request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
    response = await serving_models.unload_lora_adapter(request)

    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading an adapter that was never loaded yields ``NotFoundError``
    with HTTP 404."""
    serving_models = await _async_serving_models_init()

    request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
    response = await serving_models.unload_lora_adapter(request)

    assert isinstance(response, ErrorResponse)
    assert response.type == "NotFoundError"
    assert response.code == HTTPStatus.NOT_FOUND