conftest.py 9.71 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import tempfile
from collections import OrderedDict
6
from importlib import reload
7
from unittest.mock import MagicMock
8
9
10
11
12
13

import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

14
15
16
17
18
19
20
21
22
23
from vllm.distributed import (
    cleanup_dist_env_and_memory,
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    RowParallelLinear,
)
24
from vllm.model_executor.layers.logits_processor import LogitsProcessor
25
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
26
from vllm.model_executor.models.interfaces import SupportsLoRA
27
from vllm.platforms import current_platform
28

29

30
31
32
33
34
35
36
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether ``cleanup_fixture`` runs the global cleanup for a test.

    Subdirectories may override this fixture to skip global cleanup. Skipping
    can give a ~10x speedup for non-GPU unit tests, which then never need to
    initialize torch. A single test can also opt out by carrying the
    ``skip_global_cleanup`` marker.
    """
    opt_out_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not opt_out_marker
38
39


40
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Autouse teardown hook applied to every test.

    After the test body finishes, calls
    ``cleanup_dist_env_and_memory(shutdown_ray=True)`` unless
    ``should_do_global_cleanup_after_test`` is false.
    """
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory(shutdown_ray=True)
45
46


47
48
49
50
51
52
53
54
55
56
57
58
@pytest.fixture
def maybe_enable_lora_dual_stream(monkeypatch: pytest.MonkeyPatch):
    """Set ``VLLM_LORA_ENABLE_DUAL_STREAM=1`` for the test on CUDA platforms.

    On CUDA, the variable is set through ``monkeypatch``, so it is reverted at
    teardown. If ``vllm.lora.layers.base_linear`` does not yet expose
    ``lora_linear_async``, the module is reloaded. On other platforms this
    fixture does nothing.

    NOTE(review): the module reload is not undone at teardown, so the reloaded
    module state presumably persists into later tests. Confirm this is
    intended.
    """
    on_cuda = current_platform.is_cuda()
    if on_cuda:
        monkeypatch.setenv("VLLM_LORA_ENABLE_DUAL_STREAM", "1")
        import vllm.lora.layers.base_linear as base_linear

        if not hasattr(base_linear, "lora_linear_async"):
            # Reload so the freshly set environment variable takes effect.
            reload(base_linear)
    yield


59
60
@pytest.fixture
def dist_init():
    """Initialize a single-process (world_size=1, rank=0) vLLM distributed env.

    Uses a ``file://`` rendezvous on a fresh temporary file. The backend is
    gloo on CPU/TPU and NCCL otherwise. Model-parallel groups are created via
    ``initialize_model_parallel(1, 1)``, inside
    ``ensure_current_vllm_config()``. After the test,
    ``cleanup_dist_env_and_memory(shutdown_ray=True)`` tears everything down.
    """
    from tests.utils import ensure_current_vllm_config

    # The file:// init method needs a brand-new empty file. NamedTemporaryFile
    # closes its OS handle when the ``with`` exits, whereas
    # ``tempfile.mkstemp()[1]`` leaked the descriptor returned at index 0.
    # ``delete=False`` keeps the file on disk for torch.distributed to use.
    with tempfile.NamedTemporaryFile(delete=False) as rendezvous:
        temp_file = rendezvous.name

    # NCCL requires GPUs; fall back to gloo on CPU and TPU.
    if current_platform.is_cpu() or current_platform.is_tpu():
        backend = "gloo"
    else:
        backend = "nccl"

    with ensure_current_vllm_config():
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend=backend,
        )
        initialize_model_parallel(1, 1)
        yield
    cleanup_dist_env_and_memory(shutdown_ray=True)
80
81
82
83
84
85


@pytest.fixture
def dist_init_torch_only():
    """Initialize a bare single-process torch.distributed process group.

    Does nothing if a process group is already initialized. Otherwise it
    creates a world_size=1 / rank=0 group over a ``file://`` rendezvous. The
    backend is gloo on CPU/TPU and NCCL otherwise, consistent with
    ``dist_init``. The process group is not destroyed here, so later uses of
    this fixture hit the early return.
    """
    if torch.distributed.is_initialized():
        return

    # NCCL requires GPUs; fall back to gloo on CPU and TPU.
    if current_platform.is_cpu() or current_platform.is_tpu():
        backend = "gloo"
    else:
        backend = "nccl"

    # A fresh empty file for the file:// init method. The ``with`` closes the
    # OS handle, which ``tempfile.mkstemp()[1]`` used to leak.
    with tempfile.NamedTemporaryFile(delete=False) as rendezvous:
        temp_file = rendezvous.name

    torch.distributed.init_process_group(
        world_size=1, rank=0, init_method=f"file://{temp_file}", backend=backend
    )
94
95


96
97
98
99
class DummyLoRAModel(nn.Sequential, SupportsLoRA):
    """An ``nn.Sequential`` that also declares the ``SupportsLoRA`` interface.

    Used as the container for the ``dummy_model*`` fixtures in this file.
    """


100
@pytest.fixture
def dummy_model(default_vllm_config) -> nn.Module:
    """Build a small ``DummyLoRAModel`` out of vLLM parallel layers.

    The layer stack is: two top-level linears, a nested ``layer1`` block, a
    ReLU, an ``output`` linear, a Sigmoid, and an ``lm_head`` /
    ``logits_processor`` pair over a 32064-entry vocabulary.

    ``default_vllm_config`` is requested only for its side effect. Presumably
    it installs a current vLLM config that these layers need at construction
    time; confirm against its definition.
    """
    vocab_size = 32064

    # Layers are created in the same order as they appear in the model.
    layers = OrderedDict()
    layers["dense1"] = ColumnParallelLinear(764, 100)
    layers["dense2"] = RowParallelLinear(100, 50)
    layers["layer1"] = nn.Sequential(
        OrderedDict(
            dense1=ColumnParallelLinear(100, 10),
            dense2=RowParallelLinear(10, 50),
        )
    )
    layers["act2"] = nn.ReLU()
    layers["output"] = ColumnParallelLinear(50, 10)
    layers["outact"] = nn.Sigmoid()
    # Special handling for lm_head & sampler
    layers["lm_head"] = ParallelLMHead(vocab_size, 10)
    layers["logits_processor"] = LogitsProcessor(vocab_size)

    model = DummyLoRAModel(layers)
    model.config = MagicMock()
    model.embedding_modules = {"lm_head": "lm_head"}
    model.unpadded_vocab_size = vocab_size
    return model


@pytest.fixture
def dummy_model_gate_up(default_vllm_config) -> nn.Module:
    """Build a ``DummyLoRAModel`` variant that contains a merged projection.

    Same stack as ``dummy_model``, except that the ``output`` linear is
    replaced by ``gate_up_proj``, a ``MergedColumnParallelLinear(50, [5, 5])``.
    ``packed_modules_mapping`` declares that layer as packed
    ``gate_proj``/``up_proj``.

    ``default_vllm_config`` is requested only for its side effect. Presumably
    it installs a current vLLM config that these layers need at construction
    time; confirm against its definition.
    """
    vocab_size = 32064

    # Layers are created in the same order as they appear in the model.
    layers = OrderedDict()
    layers["dense1"] = ColumnParallelLinear(764, 100)
    layers["dense2"] = RowParallelLinear(100, 50)
    layers["layer1"] = nn.Sequential(
        OrderedDict(
            dense1=ColumnParallelLinear(100, 10),
            dense2=RowParallelLinear(10, 50),
        )
    )
    layers["act2"] = nn.ReLU()
    layers["gate_up_proj"] = MergedColumnParallelLinear(50, [5, 5])
    layers["outact"] = nn.Sigmoid()
    # Special handling for lm_head & sampler
    layers["lm_head"] = ParallelLMHead(vocab_size, 10)
    layers["logits_processor"] = LogitsProcessor(vocab_size)

    model = DummyLoRAModel(layers)
    model.config = MagicMock()
    model.packed_modules_mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}
    model.embedding_modules = {"lm_head": "lm_head"}
    model.unpadded_vocab_size = vocab_size
    return model


@pytest.fixture(scope="session")
def mixtral_lora_files():
    """Download the Mixtral LoRA snapshot once per session; return its local dir.

    This adapter deliberately ships an incorrect ``adapter_config.json``, to
    test https://github.com/vllm-project/vllm/pull/5909/files.
    """
    repo = "SangBinCho/mixtral-lora"
    return snapshot_download(repo_id=repo)
178
179


180
181
182
183
184
185
186
187
188
189
@pytest.fixture(scope="session")
def chatglm3_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/chatglm3-text2sql-spider"
    return snapshot_download(repo_id=repo)


@pytest.fixture(scope="session")
def baichuan_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/baichuan7b-text2sql-spider"
    return snapshot_download(repo_id=repo)


190
191
192
193
194
195
@pytest.fixture(scope="session")
def baichuan_zero_lora_files():
    """Session-scoped snapshot whose lora_B weights are all initialized to zero."""
    repo = "jeeejeee/baichuan7b-zero-init"
    return snapshot_download(repo_id=repo)


196
197
198
199
200
@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/baichuan-7b-lora-zero-regex"
    return snapshot_download(repo_id=repo)


201
202
203
204
205
@pytest.fixture(scope="session")
def ilama_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/ilama-text2sql-spider"
    return snapshot_download(repo_id=repo)


206
207
208
209
210
@pytest.fixture(scope="session")
def minicpmv_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/minicpmv25-lora-pokemon"
    return snapshot_download(repo_id=repo)


211
212
213
214
215
@pytest.fixture(scope="session")
def qwen2vl_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/qwen2-vl-lora-pokemon"
    return snapshot_download(repo_id=repo)


216
217
218
219
220
221
@pytest.fixture(scope="session")
def qwen25vl_base_huggingface_id():
    """Hugging Face id of the base model for tests using the qwen25vl LoRA."""
    base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
    return base_model_id


222
223
224
225
226
@pytest.fixture(scope="session")
def qwen25vl_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/qwen25-vl-lora-pokemon"
    return snapshot_download(repo_id=repo)


227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
@pytest.fixture(scope="session")
def qwen2vl_language_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "prashanth058/qwen2vl-flickr-lora-language"
    return snapshot_download(repo_id=repo)


@pytest.fixture(scope="session")
def qwen2vl_vision_tower_connector_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "prashanth058/qwen2vl-flickr-lora-tower-connector"
    return snapshot_download(repo_id=repo)


@pytest.fixture(scope="session")
def qwen2vl_vision_tower_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "prashanth058/qwen2vl-flickr-lora-tower"
    return snapshot_download(repo_id=repo)


@pytest.fixture(scope="session")
def qwen25vl_vision_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "EpochEcho/qwen2.5-3b-vl-lora-vision-connector"
    return snapshot_download(repo_id=repo)


@pytest.fixture(scope="session")
def qwen3vl_vision_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "EpochEcho/qwen3-4b-vl-lora-vision-connector"
    return snapshot_download(repo_id=repo)


252
253
254
255
256
257
258
259
260
261
262
263
@pytest.fixture(scope="session")
def qwen3_meowing_lora_files():
    """Fetch the Qwen3 "Meow" LoRA snapshot a single time per test session."""
    repo = "Jackmin108/Qwen3-0.6B-Meow-LoRA"
    return snapshot_download(repo_id=repo)


@pytest.fixture(scope="session")
def qwen3_woofing_lora_files():
    """Fetch the Qwen3 "Woof" LoRA snapshot a single time per test session."""
    repo = "Jackmin108/Qwen3-0.6B-Woof-LoRA"
    return snapshot_download(repo_id=repo)


264
265
266
267
268
@pytest.fixture(scope="session")
def tinyllama_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jashing/tinyllama-colorist-lora"
    return snapshot_download(repo_id=repo)


269
270
271
272
273
274
275
@pytest.fixture(scope="session")
def deepseekv2_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "wuchen01/DeepSeek-V2-Lite-Chat-All-LoRA"
    return snapshot_download(repo_id=repo)


@pytest.fixture(scope="session")
def gptoss20b_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/gpt-oss-20b-lora-adapter-text2sql"
    return snapshot_download(repo_id=repo)
277
278
279
280
281
282
283
284
285
286
287
288


@pytest.fixture(scope="session")
def qwen3moe_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/qwen3-moe-text2sql-spider"
    return snapshot_download(repo_id=repo)


@pytest.fixture(scope="session")
def olmoe_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/olmoe-instruct-text2sql-spider"
    return snapshot_download(repo_id=repo)


289
290
291
292
293
294
@pytest.fixture(scope="session")
def qwen3_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "charent/self_cognition_Alice"
    return snapshot_download(repo_id=repo)


@pytest.fixture(scope="session")
def llama32_lora_huggingface_id():
    """Hugging Face repo id, used to test LoRA downloading at runtime."""
    repo_id = "jeeejeee/llama32-3b-text2sql-spider"
    return repo_id


@pytest.fixture(scope="session")
def llama32_lora_files(llama32_lora_huggingface_id):
    """Local snapshot dir of the repo named by ``llama32_lora_huggingface_id``."""
    local_dir = snapshot_download(repo_id=llama32_lora_huggingface_id)
    return local_dir
303
304


305
306
307
308
309
@pytest.fixture(scope="session")
def whisper_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "chengyili2005/whisper-small-mandarin-lora"
    return snapshot_download(repo_id=repo)


310
@pytest.fixture(scope="session")
def qwen35_text_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/qwen35-4b-text-only-sql-lora"
    return snapshot_download(repo_id=repo)


315
316
317
318
319
@pytest.fixture(scope="session")
def qwen35_vl_lora_files():
    """Session-scoped: download the snapshot once and return its local dir."""
    repo = "jeeejeee/qwen35-4b-all-linear-pokemon-lora"
    return snapshot_download(repo_id=repo)


320
321
322
@pytest.fixture
def reset_default_device():
    """Restore torch's default device once the test finishes.

    Some tests, such as `test_punica_ops.py`, explicitly set the default
    device, which can leak into subsequent tests. This fixture records the
    default device before the test and sets it back afterwards.
    """
    saved_device = torch.get_default_device()
    yield
    torch.set_default_device(saved_device)