test_serving_models.py 4.99 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

9
from vllm.config import ModelConfig
10
from vllm.engine.protocol import EngineClient
11
from vllm.entrypoints.openai.engine.protocol import (
12
    ErrorResponse,
13
14
15
)
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.serve.lora.protocol import (
16
17
18
    LoadLoRAAdapterRequest,
    UnloadLoRAAdapterRequest,
)
19
from vllm.lora.request import LoRARequest
20

MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
# Expected success responses; tests fill in ``lora_name`` via ``str.format``.
LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully."
LORA_UNLOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' removed successfully."
)


async def _async_serving_models_init() -> OpenAIServingModels:
    """Create an ``OpenAIServingModels`` backed by a mocked engine client.

    The engine client is a ``MagicMock`` restricted to the ``EngineClient``
    interface. No static LoRA modules are configured (``lora_modules=None``),
    and ``init_static_loras()`` is awaited before the instance is returned.
    """
    mock_engine_client = MagicMock(spec=EngineClient)
    # Attribute reads on a spec'd MagicMock return child mocks, not real
    # values; pin max_model_len to a concrete int so consumers get a number.
    mock_model_config = MagicMock(spec=ModelConfig)
    mock_model_config.max_model_len = 2048
    mock_engine_client.model_config = mock_model_config
    mock_engine_client.input_processor = MagicMock()
    mock_engine_client.io_processor = MagicMock()

    serving_models = OpenAIServingModels(
        engine_client=mock_engine_client,
        base_model_paths=BASE_MODEL_PATHS,
        lora_modules=None,
    )
    await serving_models.init_static_loras()

    return serving_models


@pytest.mark.asyncio
async def test_serving_model_name():
    """model_name() returns the base model name, or the LoRA name if given."""
    serving_models = await _async_serving_models_init()
    # No LoRA request: the served name is the base model's name.
    assert serving_models.model_name(None) == MODEL_NAME

    request = LoRARequest(
        lora_name="adapter", lora_path="/path/to/adapter2", lora_int_id=1
    )
    assert serving_models.model_name(request) == request.lora_name


@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Loading a new adapter returns the success message and registers it."""
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter")
    assert len(serving_models.lora_requests) == 1
    # The registry is keyed by adapter name.
    assert "adapter" in serving_models.lora_requests
    assert serving_models.lora_requests["adapter"].lora_name == "adapter"


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """Loading with an empty name and path is rejected as a bad request."""
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InvalidUserInput"
    assert response.error.code == HTTPStatus.BAD_REQUEST
    # A rejected request must not register anything.
    assert len(serving_models.lora_requests) == 0


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter name twice fails without changing the registry."""
    serving_models = await _async_serving_models_init()

    # First load succeeds.
    request = LoadLoRAAdapterRequest(
        lora_name="adapter1", lora_path="/path/to/adapter1"
    )
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
    assert len(serving_models.lora_requests) == 1

    # Second load with the same name is rejected as invalid input.
    request = LoadLoRAAdapterRequest(
        lora_name="adapter1", lora_path="/path/to/adapter1"
    )
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InvalidUserInput"
    assert response.error.code == HTTPStatus.BAD_REQUEST
    assert len(serving_models.lora_requests) == 1


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Unloading a previously loaded adapter removes it from the registry."""
    serving_models = await _async_serving_models_init()

    request = LoadLoRAAdapterRequest(
        lora_name="adapter1", lora_path="/path/to/adapter1"
    )
    response = await serving_models.load_lora_adapter(request)
    # Check the precondition so an unload failure isn't masked by a bad load.
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
    assert len(serving_models.lora_requests) == 1

    request = UnloadLoRAAdapterRequest(lora_name="adapter1")
    response = await serving_models.unload_lora_adapter(request)
    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
    assert len(serving_models.lora_requests) == 0


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """Unloading with an empty name and no int id is rejected as a bad request."""
    serving_models = await _async_serving_models_init()
    request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InvalidUserInput"
    assert response.error.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading an adapter that was never loaded yields a 404 NotFoundError."""
    serving_models = await _async_serving_models_init()
    request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.error.type == "NotFoundError"
    assert response.error.code == HTTPStatus.NOT_FOUND