# test_lora_resolvers.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6

from contextlib import suppress
from dataclasses import dataclass, field
from http import HTTPStatus
7
from unittest.mock import AsyncMock, MagicMock
8
9
10

import pytest

11
from vllm.config.multimodal import MultiModalConfig
12
from vllm.entrypoints.openai.engine.protocol import CompletionRequest, ErrorResponse
13
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
14
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
15
16
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
17
from vllm.tokenizers import get_tokenizer
18
from vllm.v1.engine.async_llm import AsyncLLM
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

# Model identifier used as the default model/tokenizer of MockModelConfig and
# passed to get_tokenizer() in the serving fixture.
MODEL_NAME = "openai-community/gpt2"
# Base-model list handed to OpenAIServingModels in the serving fixture.
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

# Key under which the mock resolver is registered in LoRAResolverRegistry.
MOCK_RESOLVER_NAME = "mock_test_resolver"


@dataclass
class MockHFConfig:
    """Placeholder config object exposing a single ``model_type`` attribute.

    Used as the default ``hf_config`` value of ``MockModelConfig``.
    """

    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Minimal mock ModelConfig for testing.

    Every attribute has a default, so ``MockModelConfig()`` yields a usable
    instance. Both ``model`` and ``tokenizer`` default to ``MODEL_NAME``.
    """

    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    tokenizer_mode: str = "auto"
    max_model_len: int = 100
    tokenizer_revision: str | None = None
    # default_factory gives each instance its own config object instead of
    # sharing one mutable default across instances.
    multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
    hf_config: MockHFConfig = field(default_factory=MockHFConfig)
    logits_processors: list[str] | None = None
    logits_processor_pattern: str | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    # No annotation: this is a plain class attribute, not a dataclass field,
    # so it is not an ``__init__`` parameter.
    encoder_config = None
    generation_config: str = "auto"
    skip_tokenizer_init: bool = False

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or ``{}`` when it is ``None`` or empty."""
        return self.diff_sampling_param or {}


class MockLoRAResolver(LoRAResolver):
    """Resolver test double that recognizes two hard-coded adapter names."""

    async def resolve_lora(
        self, base_model_name: str, lora_name: str
    ) -> LoRARequest | None:
        """Map ``lora_name`` to a canned ``LoRARequest``.

        Args:
            base_model_name: Accepted for interface compatibility; not used.
            lora_name: Adapter name requested by the caller.

        Returns:
            A request with id 1 for ``"test-lora"``, a request with id 2 for
            ``"invalid-lora"`` (both with placeholder ``/fake/path/...``
            paths), or ``None`` for any other name.
        """
        if lora_name == "test-lora":
            return LoRARequest(
                lora_name="test-lora",
                lora_int_id=1,
                lora_path="/fake/path/test-lora",
            )
        if lora_name == "invalid-lora":
            return LoRARequest(
                lora_name="invalid-lora",
                lora_int_id=2,
                lora_path="/fake/path/invalid-lora",
            )
        return None


@pytest.fixture(autouse=True)
def register_mock_resolver():
    """Install a ``MockLoRAResolver`` for each test and remove it afterwards."""
    LoRAResolverRegistry.register_resolver(MOCK_RESOLVER_NAME, MockLoRAResolver())
    yield
    # Teardown: drop the registry entry if it is still present.
    LoRAResolverRegistry.resolvers.pop(MOCK_RESOLVER_NAME, None)


@pytest.fixture
def mock_serving_setup():
    """Provides a mocked engine and serving completion instance.

    Returns:
        A ``(mock_engine, serving_completion)`` tuple. ``mock_engine`` is a
        ``MagicMock`` specced on ``AsyncLLM`` whose ``add_lora`` raises
        ``ValueError`` for the ``"invalid-lora"`` adapter and returns ``True``
        otherwise, and whose ``generate`` yields nothing.
        ``serving_completion`` is an ``OpenAIServingCompletion`` built on that
        engine with ``_process_inputs`` replaced by a mock.
    """
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.errored = False

    tokenizer = get_tokenizer(MODEL_NAME)
    mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer)

    async def mock_add_lora_side_effect(lora_request: LoRARequest):
        """Simulate engine behavior when adding LoRAs."""
        if lora_request.lora_name == "invalid-lora":
            # Simulate failure during addition (e.g. invalid format)
            raise ValueError(f"Simulated failure adding LoRA: {lora_request.lora_name}")
        # Every other adapter, including "test-lora", is added successfully.
        return True

    mock_engine.add_lora = AsyncMock(side_effect=mock_add_lora_side_effect)

    async def mock_generate(*args, **kwargs):
        # The loop body never runs; the ``yield`` only makes this an async
        # generator function, so callers get an empty async iterator.
        for _ in []:
            yield _

    mock_engine.generate = MagicMock(spec=AsyncLLM.generate, side_effect=mock_generate)

    # Clear any recorded calls before handing the mocks to a test.
    mock_engine.generate.reset_mock()
    mock_engine.add_lora.reset_mock()

    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    models = OpenAIServingModels(
        engine_client=mock_engine,
        base_model_paths=BASE_MODEL_PATHS,
    )

    serving_completion = OpenAIServingCompletion(
        mock_engine, models, request_logger=None
    )

    # Replace input processing with a mock that returns a dummy engine
    # request and an empty dict.
    serving_completion._process_inputs = AsyncMock(
        return_value=(MagicMock(name="engine_request"), {})
    )

    return mock_engine, serving_completion


@pytest.mark.asyncio
async def test_serving_completion_with_lora_resolver(mock_serving_setup, monkeypatch):
    """A resolvable adapter name is added to the engine and passed to generate."""
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

    mock_engine, serving_completion = mock_serving_setup

    lora_model_name = "test-lora"
    req_found = CompletionRequest(
        model=lora_model_name,
        prompt="Generate with LoRA",
    )

    # Suppress potential errors during the mocked generate call,
    # as we are primarily checking for add_lora and generate calls
    with suppress(Exception):
        await serving_completion.create_completion(req_found)

    # The resolved adapter must have been handed to the engine exactly once.
    mock_engine.add_lora.assert_awaited_once()
    added_request = mock_engine.add_lora.call_args[0][0]
    assert isinstance(added_request, LoRARequest)
    assert added_request.lora_name == lora_model_name

    # generate must receive the same adapter through its keyword argument.
    mock_engine.generate.assert_called_once()
    generate_request = mock_engine.generate.call_args[1]["lora_request"]
    assert isinstance(generate_request, LoRARequest)
    assert generate_request.lora_name == lora_model_name


@pytest.mark.asyncio
async def test_serving_completion_resolver_not_found(mock_serving_setup, monkeypatch):
    """An adapter name the resolver does not know yields a 404 error response."""
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

    mock_engine, serving_completion = mock_serving_setup

    non_existent_model = "non-existent-lora-adapter"
    req = CompletionRequest(
        model=non_existent_model,
        prompt="what is 1+1?",
    )

    response = await serving_completion.create_completion(req)

    # Nothing was resolved, so the engine must not have been touched.
    mock_engine.add_lora.assert_not_awaited()
    mock_engine.generate.assert_not_called()

    assert isinstance(response, ErrorResponse)
    assert response.error.code == HTTPStatus.NOT_FOUND.value
    assert non_existent_model in response.error.message
184
185
186
187


@pytest.mark.asyncio
async def test_serving_completion_resolver_add_lora_fails(
    mock_serving_setup, monkeypatch
):
    """A resolved adapter that the engine rejects yields a 400 error response."""
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

    mock_engine, serving_completion = mock_serving_setup

    invalid_model = "invalid-lora"
    req = CompletionRequest(
        model=invalid_model,
        prompt="what is 1+1?",
    )

    response = await serving_completion.create_completion(req)

    # Assert add_lora was called before the failure
    mock_engine.add_lora.assert_awaited_once()
    called_lora_request = mock_engine.add_lora.call_args[0][0]
    assert isinstance(called_lora_request, LoRARequest)
    assert called_lora_request.lora_name == invalid_model

    # Assert generate was *not* called due to the failure
    mock_engine.generate.assert_not_called()

    # Assert the correct error response
    assert isinstance(response, ErrorResponse)
    assert response.error.code == HTTPStatus.BAD_REQUEST.value
    assert invalid_model in response.error.message
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230


@pytest.mark.asyncio
async def test_serving_completion_flag_not_set(mock_serving_setup, monkeypatch):
    """With the runtime-update flag unset, add_lora and generate are not called."""
    # Make "flag not set" hold even when the variable is exported in the
    # environment that launched pytest.
    monkeypatch.delenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", raising=False)

    mock_engine, serving_completion = mock_serving_setup

    lora_model_name = "test-lora"
    req_found = CompletionRequest(
        model=lora_model_name,
        prompt="Generate with LoRA",
    )

    await serving_completion.create_completion(req_found)

    mock_engine.add_lora.assert_not_called()
    mock_engine.generate.assert_not_called()