conftest.py 8.91 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import tempfile
from collections import OrderedDict
5
from unittest.mock import MagicMock, patch
6
7

import pytest
8
import os
9
10
11
12
13
14
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
15
from vllm.distributed import (cleanup_dist_env_and_memory,
16
17
                              init_distributed_environment,
                              initialize_model_parallel)
18
19
20
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
21
22
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
23
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
24
from vllm.model_executor.model_loader import get_model
25
from vllm.model_executor.models.interfaces import SupportsLoRA
26
from vllm.platforms import current_platform
27
from ..utils import models_path_prefix
28

29

30
31
32
33
34
35
36
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether the global cleanup should run after the current test.

    Returns False when the requesting test node carries the
    ``skip_global_cleanup`` marker, True otherwise. Subdirectories may
    override this fixture to opt out of global cleanup, which can give a
    ~10x speedup for non-GPU unit tests because they do not need to
    initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
38
39


40
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Autouse teardown hook that runs the global cleanup after each test.

    Nothing happens before the test. Afterwards,
    ``cleanup_dist_env_and_memory(shutdown_ray=True)`` is called unless
    ``should_do_global_cleanup_after_test`` is falsy.
    """
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory(shutdown_ray=True)
45
46
47
48


@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment around one test.

    Before the test: calls ``init_distributed_environment`` with
    ``world_size=1``/``rank=0`` using a ``file://`` init method backed by a
    fresh temporary file, then ``initialize_model_parallel(1, 1)``.
    The backend is ``"gloo"`` when the current platform is CPU and
    ``"nccl"`` otherwise.

    After the test: calls ``cleanup_dist_env_and_memory(shutdown_ray=True)``.
    """
    # mkstemp() hands back an already-open OS-level descriptor together with
    # the path. Only the path is needed for the rendezvous URL, so close the
    # descriptor immediately instead of leaking one per test.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)

    backend = "gloo" if current_platform.is_cpu() else "nccl"

    init_distributed_environment(world_size=1,
                                 rank=0,
                                 distributed_init_method=f"file://{temp_file}",
                                 local_rank=0,
                                 backend=backend)
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory(shutdown_ray=True)
63
64
65
66
67
68


@pytest.fixture
def dist_init_torch_only():
    """Ensure a ``torch.distributed`` process group exists for the test.

    Does nothing if ``torch.distributed`` is already initialized. Otherwise
    initializes a process group with ``world_size=1``/``rank=0`` through a
    ``file://`` init method backed by a fresh temporary file. The backend
    is ``"gloo"`` when the current platform is CPU and ``"nccl"`` otherwise.
    No teardown is performed by this fixture.
    """
    if torch.distributed.is_initialized():
        return

    backend = "gloo" if current_platform.is_cpu() else "nccl"

    # mkstemp() returns (open descriptor, path). Only the path is used, so
    # close the descriptor right away rather than leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)

    torch.distributed.init_process_group(world_size=1,
                                         rank=0,
                                         init_method=f"file://{temp_file}",
                                         backend=backend)
78
79


80
81
82
83
class DummyLoRAModel(nn.Sequential, SupportsLoRA):
    """``nn.Sequential`` container that also inherits from ``SupportsLoRA``.

    Adds no behavior of its own; the dummy-model fixtures in this file
    instantiate it from an ordered mapping of named layers.
    """


84
85
@pytest.fixture
def dummy_model() -> nn.Module:
    """Build a small ``DummyLoRAModel`` made of parallel linear layers.

    The model contains two top-level dense layers, a nested ``layer1``
    sub-sequence, an ``output`` projection, activations, and the
    ``lm_head`` / ``logits_processor`` / ``sampler`` entries. ``config`` is
    a ``MagicMock`` and ``embedding_modules`` maps ``lm_head`` to itself.
    """
    # Sub-block registered under "layer1".
    inner_layers = OrderedDict()

    layers = OrderedDict()
    layers["dense1"] = ColumnParallelLinear(764, 100)
    layers["dense2"] = RowParallelLinear(100, 50)

    inner_layers["dense1"] = ColumnParallelLinear(100, 10)
    inner_layers["dense2"] = RowParallelLinear(10, 50)
    layers["layer1"] = nn.Sequential(inner_layers)

    layers["act2"] = nn.ReLU()
    layers["output"] = ColumnParallelLinear(50, 10)
    layers["outact"] = nn.Sigmoid()

    # Special handling for lm_head & sampler
    layers["lm_head"] = ParallelLMHead(512, 10)
    layers["logits_processor"] = LogitsProcessor(512)
    layers["sampler"] = Sampler()

    model = DummyLoRAModel(layers)
    model.config = MagicMock()
    model.embedding_modules = {"lm_head": "lm_head"}
    return model


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    """Build a ``DummyLoRAModel`` whose projection is a merged gate/up layer.

    Same layout as the plain dummy model except that the layer after
    ``act2`` is ``gate_up_proj``, a ``MergedColumnParallelLinear(50, [5, 5])``.
    ``packed_modules_mapping`` declares that ``gate_up_proj`` packs
    ``gate_proj`` and ``up_proj``. ``config`` is a ``MagicMock`` and
    ``embedding_modules`` maps ``lm_head`` to itself.
    """
    # Sub-block registered under "layer1".
    inner_layers = OrderedDict()

    layers = OrderedDict()
    layers["dense1"] = ColumnParallelLinear(764, 100)
    layers["dense2"] = RowParallelLinear(100, 50)

    inner_layers["dense1"] = ColumnParallelLinear(100, 10)
    inner_layers["dense2"] = RowParallelLinear(10, 50)
    layers["layer1"] = nn.Sequential(inner_layers)

    layers["act2"] = nn.ReLU()
    layers["gate_up_proj"] = MergedColumnParallelLinear(50, [5, 5])
    layers["outact"] = nn.Sigmoid()

    # Special handling for lm_head & sampler
    layers["lm_head"] = ParallelLMHead(512, 10)
    layers["logits_processor"] = LogitsProcessor(512)
    layers["sampler"] = Sampler()

    model = DummyLoRAModel(layers)
    model.config = MagicMock()
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model.embedding_modules = {"lm_head": "lm_head"}
    return model


@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    """Return the identifier used to test LoRA runtime downloading.

    The value is ``yard1/llama-2-7b-sql-lora-test`` joined onto
    ``models_path_prefix``.
    """
    repo_subpath = "yard1/llama-2-7b-sql-lora-test"
    return os.path.join(models_path_prefix, repo_subpath)
148
149
150
151
152


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
    """Return a directory containing the SQL LoRA test files.

    ``sql_lora_huggingface_id`` is built by joining ``models_path_prefix``
    with the repo name, so it may be a local filesystem path rather than a
    plain Hub repo id. When that path exists as a directory it is returned
    directly, matching the other ``*_lora_files`` fixtures that resolve to
    local paths. Otherwise the value is treated as a Hub repo id and
    fetched with ``snapshot_download``.
    """
    if os.path.isdir(sql_lora_huggingface_id):
        return sql_lora_huggingface_id
    return snapshot_download(repo_id=sql_lora_huggingface_id)
153
154


Terry's avatar
Terry committed
155
156
@pytest.fixture(scope="session")
def mixtral_lora_files():
    """Return the path of ``SangBinCho/mixtral-lora`` under ``models_path_prefix``.

    Note: this module has incorrect adapter_config.json to test
    https://github.com/vllm-project/vllm/pull/5909/files.

    The path is resolved locally rather than via ``snapshot_download`` of
    the repo with the same id.
    """
    repo_subpath = "SangBinCho/mixtral-lora"
    return os.path.join(models_path_prefix, repo_subpath)
Terry's avatar
Terry committed
161
162


163
164
@pytest.fixture(scope="session")
def gemma_lora_files():
    """Return the path of ``wskwon/gemma-7b-test-lora`` under ``models_path_prefix``.

    The path is resolved locally rather than via ``snapshot_download`` of
    the repo with the same id.
    """
    repo_subpath = "wskwon/gemma-7b-test-lora"
    return os.path.join(models_path_prefix, repo_subpath)
167
168


169
170
@pytest.fixture(scope="session")
def chatglm3_lora_files():
    """Return the path of ``jeeejeee/chatglm3-text2sql-spider`` under ``models_path_prefix``.

    The path is resolved locally rather than via ``snapshot_download`` of
    the repo with the same id.
    """
    repo_subpath = "jeeejeee/chatglm3-text2sql-spider"
    return os.path.join(models_path_prefix, repo_subpath)
173
174
175
176


@pytest.fixture(scope="session")
def baichuan_lora_files():
    """Return the path of ``jeeejeee/baichuan7b-text2sql-spider`` under ``models_path_prefix``.

    The path is resolved locally rather than via ``snapshot_download`` of
    the repo with the same id.
    """
    repo_subpath = "jeeejeee/baichuan7b-text2sql-spider"
    return os.path.join(models_path_prefix, repo_subpath)
179
180


181
182
183
@pytest.fixture(scope="session")
def baichuan_zero_lora_files():
    """Return the path of ``jeeejeee/baichuan7b-zero-init`` under ``models_path_prefix``.

    All the lora_B weights in this adapter are initialized to zero. The
    path is resolved locally rather than via ``snapshot_download`` of the
    repo with the same id.
    """
    repo_subpath = "jeeejeee/baichuan7b-zero-init"
    return os.path.join(models_path_prefix, repo_subpath)
186
187


188
189
190
191
192
@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
    """Return the ``snapshot_download`` result for ``jeeejeee/baichuan-7b-lora-zero-regex``."""
    repo = "jeeejeee/baichuan-7b-lora-zero-regex"
    return snapshot_download(repo_id=repo)


193
194
195
196
197
@pytest.fixture(scope="session")
def ilama_lora_files():
    """Return the ``snapshot_download`` result for ``jeeejeee/ilama-text2sql-spider``."""
    repo = "jeeejeee/ilama-text2sql-spider"
    return snapshot_download(repo_id=repo)


198
199
200
201
202
@pytest.fixture(scope="session")
def minicpmv_lora_files():
    """Return the ``snapshot_download`` result for ``jeeejeee/minicpmv25-lora-pokemon``."""
    repo = "jeeejeee/minicpmv25-lora-pokemon"
    return snapshot_download(repo_id=repo)


203
204
205
206
207
@pytest.fixture(scope="session")
def qwen2vl_lora_files():
    """Return the ``snapshot_download`` result for ``jeeejeee/qwen2-vl-lora-pokemon``."""
    repo = "jeeejeee/qwen2-vl-lora-pokemon"
    return snapshot_download(repo_id=repo)


208
209
210
211
212
@pytest.fixture(scope="session")
def qwen25vl_lora_files():
    """Return the ``snapshot_download`` result for ``jeeejeee/qwen25-vl-lora-pokemon``."""
    repo = "jeeejeee/qwen25-vl-lora-pokemon"
    return snapshot_download(repo_id=repo)


213
214
@pytest.fixture(scope="session")
def tinyllama_lora_files():
    """Return the path of ``jashing/tinyllama-colorist-lora`` under ``models_path_prefix``.

    The path is resolved locally rather than via ``snapshot_download`` of
    the repo with the same id.
    """
    repo_subpath = "jashing/tinyllama-colorist-lora"
    return os.path.join(models_path_prefix, repo_subpath)
217
218


219
220
@pytest.fixture(scope="session")
def phi2_lora_files():
    """Return the path of ``isotr0py/phi-2-test-sql-lora`` under ``models_path_prefix``.

    The path is resolved locally rather than via ``snapshot_download`` of
    the repo with the same id.
    """
    repo_subpath = "isotr0py/phi-2-test-sql-lora"
    return os.path.join(models_path_prefix, repo_subpath)
223

王敏's avatar
王敏 committed
224
225
226
227
228
@pytest.fixture(scope="session")
def qwen_lora_files():
    """Return the path of ``customize/qwen-nl2dsl-lora`` under ``models_path_prefix``."""
    repo_subpath = "customize/qwen-nl2dsl-lora"
    return os.path.join(models_path_prefix, repo_subpath)

229

230
231
@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
    """Return the path of ``SangBinCho/long_context_16k_testing_1`` under ``models_path_prefix``.

    The path is resolved locally rather than via ``snapshot_download`` of
    the repo with the same id.
    """
    repo_subpath = "SangBinCho/long_context_16k_testing_1"
    return os.path.join(models_path_prefix, repo_subpath)
234
235


236
@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
    """Yield the ``llm_engine`` of a Llama-2-7b ``vllm.LLM`` loaded with a LoRA config.

    The ``LLM`` is created with ``enable_lora=False``. While it is being
    constructed, ``vllm.worker.model_runner.get_model`` is patched with a
    wrapper that sets ``LoRAConfig(max_loras=4, max_lora_rank=8)`` on the
    incoming ``vllm_config`` before delegating to the real ``get_model``.

    ``cleanup_dist_env_and_memory(shutdown_ray=True)`` runs both before the
    engine is built and after the test finishes.
    """
    cleanup_dist_env_and_memory(shutdown_ray=True)

    # Keep a handle on the real loader so the wrapper can delegate to it.
    original_get_model = get_model

    def get_model_patched(**kwargs):
        lora_cfg = LoRAConfig(max_loras=4, max_lora_rank=8)
        kwargs["vllm_config"].lora_config = lora_cfg
        return original_get_model(**kwargs)

    model_path = os.path.join(models_path_prefix, "meta-llama/Llama-2-7b-hf")
    with patch("vllm.worker.model_runner.get_model", get_model_patched):
        engine = vllm.LLM(model_path, enable_lora=False)

    yield engine.llm_engine

    # Drop the local reference before the global cleanup runs.
    del engine
    cleanup_dist_env_and_memory(shutdown_ray=True)
251
252
253


@pytest.fixture
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
    """Yield the model object held by the engine's driver worker.

    Follows ``model_executor.driver_worker.model_runner.model`` on the
    engine provided by ``llama_2_7b_engine_extra_embeddings``.
    """
    engine = llama_2_7b_engine_extra_embeddings
    runner = engine.model_executor.driver_worker.model_runner
    yield runner.model
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273


@pytest.fixture(params=[True, False])
def run_with_both_engines_lora(request, monkeypatch):
    """Run the requesting test twice: once with V1 enabled and once without.

    The parameter selects the value of the ``VLLM_USE_V1`` environment
    variable (``'1'`` or ``'0'``), set through ``monkeypatch``. Tests
    carrying the ``skip_v1`` marker are skipped for the V1 run and only
    execute without V1.
    """
    use_v1 = request.param
    skip_v1 = request.node.get_closest_marker("skip_v1")

    # pytest.skip() raises, so the environment is left untouched when a
    # skip_v1-marked test is parametrized with V1.
    if use_v1 and skip_v1:
        pytest.skip("Skipping test on vllm V1")

    monkeypatch.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    yield