"""test_serving_engine.py: LoRA adapter load/unload tests for OpenAIServing."""

from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import OpenAIServing

# Name under which the engine under test serves its single model.
MODEL_NAME = "meta-llama/Llama-2-7b"

# Expected success responses for loading / unloading an adapter; call
# ``.format(lora_name=...)`` to obtain the exact string to compare against.
LORA_LOADING_SUCCESS_MESSAGE = ("Success: LoRA adapter '{lora_name}' "
                                "added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = ("Success: LoRA adapter '{lora_name}' "
                                  "removed successfully.")


async def _async_serving_engine_init():
    """Build an ``OpenAIServing`` instance wired to mocked dependencies.

    The engine client and model config are ``MagicMock`` objects whose
    ``spec`` is the real ``EngineClient`` / ``ModelConfig`` class. The
    instance serves ``MODEL_NAME`` and is constructed with
    ``lora_modules=None``, ``prompt_adapters=None`` and no request logger.

    Declared ``async`` even though it awaits nothing; the tests ``await`` it.

    Returns:
        OpenAIServing: the freshly constructed serving engine.
    """
    mock_engine_client = MagicMock(spec=EngineClient)
    mock_model_config = MagicMock(spec=ModelConfig)
    # ``spec=`` only exposes attributes visible on the class, so give
    # max_model_len a concrete value; presumably OpenAIServing reads it
    # during construction (TODO confirm).
    mock_model_config.max_model_len = 2048

    serving_engine = OpenAIServing(mock_engine_client,
                                   mock_model_config,
                                   served_model_names=[MODEL_NAME],
                                   lora_modules=None,
                                   prompt_adapters=None,
                                   request_logger=None)
    return serving_engine


@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """A new adapter loads, yields the success message and is registered."""
    engine = await _async_serving_engine_init()
    load_req = LoadLoraAdapterRequest(lora_name="adapter",
                                      lora_path="/path/to/adapter2")

    result = await engine.load_lora_adapter(load_req)

    expected = LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter")
    assert result == expected
    registered = engine.lora_requests
    assert len(registered) == 1
    assert registered[0].lora_name == "adapter"


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """Empty name and path are rejected with a 400 InvalidUserInput error."""
    engine = await _async_serving_engine_init()
    bad_req = LoadLoraAdapterRequest(lora_name="", lora_path="")

    err = await engine.load_lora_adapter(bad_req)

    assert isinstance(err, ErrorResponse)
    assert err.type == "InvalidUserInput"
    assert err.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter name twice fails the second time."""
    engine = await _async_serving_engine_init()

    def make_request():
        # A fresh request object per call, as in two independent API calls.
        return LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")

    first = await engine.load_lora_adapter(make_request())
    assert first == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
    assert len(engine.lora_requests) == 1

    # The repeat load must be rejected and leave the registry unchanged.
    second = await engine.load_lora_adapter(make_request())
    assert isinstance(second, ErrorResponse)
    assert second.type == "InvalidUserInput"
    assert second.code == HTTPStatus.BAD_REQUEST
    assert len(engine.lora_requests) == 1


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Unloading a previously loaded adapter removes it from the registry."""
    serving_engine = await _async_serving_engine_init()
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_engine.load_lora_adapter(request)
    # Verify the precondition explicitly so a failed load is reported here,
    # not as a misleading unload failure further down.
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_engine.lora_requests) == 1

    request = UnloadLoraAdapterRequest(lora_name="adapter1")
    response = await serving_engine.unload_lora_adapter(request)
    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_engine.lora_requests) == 0


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """Unload without a name or integer id is a 400 InvalidUserInput error."""
    engine = await _async_serving_engine_init()
    bad_req = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)

    err = await engine.unload_lora_adapter(bad_req)

    assert isinstance(err, ErrorResponse)
    assert err.type == "InvalidUserInput"
    assert err.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading a name that was never loaded is a 400 InvalidUserInput."""
    engine = await _async_serving_engine_init()
    unknown = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")

    err = await engine.unload_lora_adapter(unknown)

    assert isinstance(err, ErrorResponse)
    assert err.type == "InvalidUserInput"
    assert err.code == HTTPStatus.BAD_REQUEST