# test_serving_models.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
from http import HTTPStatus
from unittest.mock import MagicMock

import pytest
8
import os
9
10

from vllm.config import ModelConfig
11
from vllm.engine.protocol import EngineClient
12
13
14
15
16
17
from vllm.entrypoints.openai.protocol import (
    ErrorResponse,
    LoadLoRAAdapterRequest,
    UnloadLoRAAdapterRequest,
)
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
18
from vllm.lora.request import LoRARequest
19
from ...utils import models_path_prefix
20

21
22

# Base model exposed by the mocked engine; the repo id is joined onto the
# shared `models_path_prefix` from the test utils.
MODEL_NAME = os.path.join(models_path_prefix, "hmellor/tiny-random-LlamaForCausalLM")
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

# Expected success responses; tests fill the placeholder via
# `.format(lora_name=...)`.
LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully."
LORA_UNLOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' removed successfully."
)
28
29


30
async def _async_serving_models_init() -> OpenAIServingModels:
    """Build an ``OpenAIServingModels`` backed by a mocked engine client.

    The engine client is a ``MagicMock`` spec'd against ``EngineClient``, so
    no real engine is started. No static LoRA modules are passed
    (``lora_modules=None``); the tests below rely on ``lora_requests``
    starting out empty.

    Returns:
        The initialized ``OpenAIServingModels`` instance.
    """
    mock_engine_client = MagicMock(spec=EngineClient)

    # Give the model config a concrete max_model_len so code that reads it
    # gets an int instead of a missing/mock attribute.
    mock_model_config = MagicMock(spec=ModelConfig)
    mock_model_config.max_model_len = 2048
    mock_engine_client.model_config = mock_model_config

    # Processors are stubbed explicitly. Presumably they are accessed while
    # the serving layer is constructed (TODO confirm which are required).
    mock_engine_client.input_processor = MagicMock()
    mock_engine_client.io_processor = MagicMock()

    serving_models = OpenAIServingModels(
        engine_client=mock_engine_client,
        base_model_paths=BASE_MODEL_PATHS,
        lora_modules=None,
    )
    await serving_models.init_static_loras()

    return serving_models
47
48


49
50
@pytest.mark.asyncio
async def test_serving_model_name():
    """model_name() falls back to the base model and prefers a LoRA's name."""
    serving_models = await _async_serving_models_init()

    # Without a LoRA request the base model name is reported.
    assert serving_models.model_name(None) == MODEL_NAME

    # With a LoRA request the adapter's name is reported instead.
    request = LoRARequest(
        lora_name="adapter", lora_path="/path/to/adapter2", lora_int_id=1
    )
    assert serving_models.model_name(request) == request.lora_name
57
58


59
60
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Loading a new adapter returns the success message and registers it."""
    serving_models = await _async_serving_models_init()

    request = LoadLoRAAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2")
    response = await serving_models.load_lora_adapter(request)

    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter")
    assert len(serving_models.lora_requests) == 1
    assert "adapter" in serving_models.lora_requests
    assert serving_models.lora_requests["adapter"].lora_name == "adapter"
68
69
70
71


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """Empty name/path is rejected as a 400 InvalidUserInput error."""
    serving_models = await _async_serving_models_init()

    request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
    response = await serving_models.load_lora_adapter(request)

    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InvalidUserInput"
    assert response.error.code == HTTPStatus.BAD_REQUEST
78
79
80
81


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter twice fails the second time without side effects."""
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(
        lora_name="adapter1", lora_path="/path/to/adapter1"
    )

    # First load succeeds and registers exactly one adapter.
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
    assert len(serving_models.lora_requests) == 1

    # Re-sending the identical request is rejected as a 400 error...
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InvalidUserInput"
    assert response.error.code == HTTPStatus.BAD_REQUEST
    # ...and leaves the registry unchanged.
    assert len(serving_models.lora_requests) == 1
98
99
100
101


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """A previously loaded adapter can be unloaded, emptying the registry."""
    serving_models = await _async_serving_models_init()

    # Setup: load the adapter and confirm it actually succeeded, so a load
    # failure is reported here rather than as a confusing unload failure.
    request = LoadLoRAAdapterRequest(
        lora_name="adapter1", lora_path="/path/to/adapter1"
    )
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
    assert len(serving_models.lora_requests) == 1

    request = UnloadLoRAAdapterRequest(lora_name="adapter1")
    response = await serving_models.unload_lora_adapter(request)
    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
    assert len(serving_models.lora_requests) == 0
113
114
115
116


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """Unloading with an empty name and no id is a 400 InvalidUserInput error."""
    serving_models = await _async_serving_models_init()

    request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
    response = await serving_models.unload_lora_adapter(request)

    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InvalidUserInput"
    assert response.error.code == HTTPStatus.BAD_REQUEST
123
124
125
126


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading an adapter that was never loaded is a 404 NotFoundError."""
    serving_models = await _async_serving_models_init()

    request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
    response = await serving_models.unload_lora_adapter(request)

    assert isinstance(response, ErrorResponse)
    assert response.error.type == "NotFoundError"
    assert response.error.code == HTTPStatus.NOT_FOUND