test_serving_models.py 5.22 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
from http import HTTPStatus
from unittest.mock import MagicMock

import pytest
8
import os
9
10

from vllm.config import ModelConfig
11
from vllm.engine.protocol import EngineClient
12
from vllm.entrypoints.openai.protocol import (ErrorResponse,
13
14
                                              LoadLoRAAdapterRequest,
                                              UnloadLoRAAdapterRequest)
15
16
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
17
from vllm.lora.request import LoRARequest
18
from ...utils import models_path_prefix
19

zhuwenwen's avatar
zhuwenwen committed
20
# Model identifier shared by every test in this module: the Llama checkpoint
# name joined onto the project-provided ``models_path_prefix``.
MODEL_NAME = os.path.join(
    models_path_prefix,
    "meta-llama/Llama-3.2-1B-Instruct",
)
# The single base model is registered with the same string as name and path.
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
]
22
23
24
25
26
27
# Templates for the plain-string responses the tests below compare against
# after a LoRA adapter is loaded or unloaded; fill in with
# ``.format(lora_name=...)``.
LORA_LOADING_SUCCESS_MESSAGE = ("Success: LoRA adapter '{lora_name}' "
                                "added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = ("Success: LoRA adapter '{lora_name}' "
                                  "removed successfully.")


28
async def _async_serving_models_init() -> OpenAIServingModels:
    """Build an ``OpenAIServingModels`` instance wired to mock collaborators.

    The engine client and model config are ``MagicMock(spec=...)`` stand-ins
    for ``EngineClient`` and ``ModelConfig``.

    Returns:
        A serving-models object over ``BASE_MODEL_PATHS`` with
        ``lora_modules=None``, returned after ``init_static_loras()`` has
        been awaited.
    """
    mock_model_config = MagicMock(spec=ModelConfig)
    mock_engine_client = MagicMock(spec=EngineClient)
    # Set the max_model_len attribute to avoid missing attribute
    mock_model_config.max_model_len = 2048

    serving_models = OpenAIServingModels(engine_client=mock_engine_client,
                                         base_model_paths=BASE_MODEL_PATHS,
                                         model_config=mock_model_config,
                                         lora_modules=None)
    await serving_models.init_static_loras()

    return serving_models
41
42


43
44
@pytest.mark.asyncio
async def test_serving_model_name():
    """``model_name`` returns the base name for ``None``, else the LoRA name."""
    serving_models = await _async_serving_models_init()

    # Without a LoRA request the base model name is reported.
    assert serving_models.model_name(None) == MODEL_NAME

    # With a LoRA request the adapter's own name is reported instead.
    request = LoRARequest(lora_name="adapter",
                          lora_path="/path/to/adapter2",
                          lora_int_id=1)
    assert serving_models.model_name(request) == request.lora_name
51
52


53
54
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Loading a new adapter returns the success message and registers it."""
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="adapter",
                                     lora_path="/path/to/adapter2")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')

    # The adapter must be tracked under the name it was loaded with.
    assert len(serving_models.lora_requests) == 1
    assert "adapter" in serving_models.lora_requests
    assert serving_models.lora_requests["adapter"].lora_name == "adapter"
63
64
65
66


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """An empty adapter name and path yield an ``InvalidUserInput`` 400."""
    serving_models = await _async_serving_models_init()
    request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
    response = await serving_models.load_lora_adapter(request)

    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter twice is rejected and leaves one entry."""
    serving_models = await _async_serving_models_init()

    # First load succeeds and registers the adapter.
    request = LoadLoRAAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    # Second load with identical name and path must come back as an error.
    request = LoadLoRAAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST
    # The rejected request must not add a second entry.
    assert len(serving_models.lora_requests) == 1
92
93
94
95


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Unloading a loaded adapter returns the success message and removes it."""
    serving_models = await _async_serving_models_init()

    # Arrange: load the adapter and confirm the load itself succeeded, so a
    # failure in the setup step is reported as such.
    request = LoadLoRAAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_models.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 1

    # Act and assert: unloading by name empties the registry.
    request = UnloadLoRAAdapterRequest(lora_name="adapter1")
    response = await serving_models.unload_lora_adapter(request)
    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_models.lora_requests) == 0
107
108
109
110


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """An empty name with no ``lora_int_id`` yields an ``InvalidUserInput`` 400."""
    serving_models = await _async_serving_models_init()
    request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
    response = await serving_models.unload_lora_adapter(request)

    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading an adapter that was never loaded yields a ``NotFoundError`` 404."""
    serving_models = await _async_serving_models_init()
    request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
    response = await serving_models.unload_lora_adapter(request)

    assert isinstance(response, ErrorResponse)
    assert response.type == "NotFoundError"
    assert response.code == HTTPStatus.NOT_FOUND