test_serving_models.py 5.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

9
from vllm.config import ModelConfig, RendererConfig
10
from vllm.engine.protocol import EngineClient
11
12
13
14
15
16
from vllm.entrypoints.openai.protocol import (
    ErrorResponse,
    LoadLoRAAdapterRequest,
    UnloadLoRAAdapterRequest,
)
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
17
from vllm.lora.request import LoRARequest
18

19
# Model identifier used both as the served name and as the model path.
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

# Message templates the tests compare against load/unload responses;
# ``{lora_name}`` is filled in with ``str.format`` at the assertion site.
LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully."
LORA_UNLOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' removed successfully."
)
25
26


27
async def _async_serving_models_init() -> OpenAIServingModels:
    """Build an ``OpenAIServingModels`` instance backed by mocked dependencies.

    The engine client, model config and renderer config are ``MagicMock``
    objects created with the corresponding classes as ``spec``. No LoRA
    modules are configured up front, and ``init_static_loras`` is awaited
    before the instance is handed back.

    Returns:
        The initialised ``OpenAIServingModels`` object.
    """
    # max_model_len is set explicitly so the attribute is present on the
    # spec'd mock instead of being reported as missing.
    model_config = MagicMock(spec=ModelConfig, max_model_len=2048)
    renderer_config = MagicMock(spec=RendererConfig, model_config=model_config)

    engine_client = MagicMock(spec=EngineClient)
    engine_client.model_config = model_config
    engine_client.renderer_config = renderer_config
    engine_client.input_processor = MagicMock()
    engine_client.io_processor = MagicMock()

    models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=BASE_MODEL_PATHS,
        lora_modules=None,
    )
    await models.init_static_loras()
    return models
50
51


52
53
@pytest.mark.asyncio
async def test_serving_model_name():
    """Check ``model_name`` both without and with a LoRA request."""
    models = await _async_serving_models_init()

    # With no LoRA request, the base model name is reported.
    assert models.model_name(None) == MODEL_NAME

    # With a LoRA request, the adapter's own name is reported.
    lora_request = LoRARequest(
        lora_name="adapter",
        lora_path="/path/to/adapter2",
        lora_int_id=1,
    )
    assert models.model_name(lora_request) == lora_request.lora_name
60
61


62
63
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Loading a new adapter returns the success message and registers it."""
    models = await _async_serving_models_init()
    adapter_name = "adapter"
    load_request = LoadLoRAAdapterRequest(
        lora_name=adapter_name,
        lora_path="/path/to/adapter2",
    )

    result = await models.load_lora_adapter(load_request)

    assert result == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name=adapter_name)
    # Exactly one adapter is registered, keyed by and carrying its name.
    registered = models.lora_requests
    assert len(registered) == 1
    assert adapter_name in registered
    assert registered[adapter_name].lora_name == adapter_name
71
72
73
74


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """A load request with empty name and path yields a bad-request error."""
    models = await _async_serving_models_init()
    empty_request = LoadLoRAAdapterRequest(lora_name="", lora_path="")

    result = await models.load_lora_adapter(empty_request)

    assert isinstance(result, ErrorResponse)
    error = result.error
    assert error.type == "InvalidUserInput"
    assert error.code == HTTPStatus.BAD_REQUEST
81
82
83
84


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter twice: the first succeeds, the second fails."""
    models = await _async_serving_models_init()
    adapter_kwargs = {"lora_name": "adapter1", "lora_path": "/path/to/adapter1"}

    # First load registers the adapter.
    first = await models.load_lora_adapter(LoadLoRAAdapterRequest(**adapter_kwargs))
    assert first == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
    assert len(models.lora_requests) == 1

    # A fresh but identical request is rejected and the registry is unchanged.
    second = await models.load_lora_adapter(LoadLoRAAdapterRequest(**adapter_kwargs))
    assert isinstance(second, ErrorResponse)
    error = second.error
    assert error.type == "InvalidUserInput"
    assert error.code == HTTPStatus.BAD_REQUEST
    assert len(models.lora_requests) == 1
101
102
103
104


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Unloading a previously loaded adapter succeeds and empties the registry.

    The adapter is loaded first; the load response is checked so that a
    failing setup step is reported as such rather than surfacing later as a
    confusing unload failure.
    """
    serving_models = await _async_serving_models_init()

    load_request = LoadLoRAAdapterRequest(
        lora_name="adapter1", lora_path="/path/to/adapter1"
    )
    load_response = await serving_models.load_lora_adapter(load_request)
    # Previously this response was assigned and then overwritten unchecked;
    # assert it the same way test_load_lora_adapter_duplicate does.
    assert load_response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
    assert len(serving_models.lora_requests) == 1

    unload_request = UnloadLoRAAdapterRequest(lora_name="adapter1")
    unload_response = await serving_models.unload_lora_adapter(unload_request)
    assert unload_response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name="adapter1"
    )
    assert len(serving_models.lora_requests) == 0
116
117
118
119


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """An unload request with an empty name and no id yields a bad request."""
    models = await _async_serving_models_init()
    empty_request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)

    result = await models.unload_lora_adapter(empty_request)

    assert isinstance(result, ErrorResponse)
    error = result.error
    assert error.type == "InvalidUserInput"
    assert error.code == HTTPStatus.BAD_REQUEST
126
127
128
129


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading an adapter that was never loaded yields a not-found error."""
    models = await _async_serving_models_init()
    unknown_request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")

    result = await models.unload_lora_adapter(unknown_request)

    assert isinstance(result, ErrorResponse)
    error = result.error
    assert error.type == "NotFoundError"
    assert error.code == HTTPStatus.NOT_FOUND