# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import tempfile
from collections import OrderedDict
from unittest.mock import MagicMock

import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

from vllm.distributed import (cleanup_dist_env_and_memory,
                              init_distributed_environment,
                              initialize_model_parallel)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.platforms import current_platform


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.

    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch. Individual tests can also opt out by applying
    the ``skip_global_cleanup`` marker.

    Returns:
        True unless the requesting test carries ``skip_global_cleanup``.
    """
    return not request.node.get_closest_marker("skip_global_cleanup")


@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Tear down distributed state and memory after every test (autouse).

    The teardown is skipped when ``should_do_global_cleanup_after_test`` is
    False, e.g. for tests marked ``skip_global_cleanup`` or for directories
    that override that fixture.
    """
    yield
    if should_do_global_cleanup_after_test:
        cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
def dist_init():
    """Initialize a single-process vLLM distributed environment for a test.

    A fresh temporary file serves as the ``file://`` rendezvous point.
    ``gloo`` is used on CPU/TPU platforms and ``nccl`` otherwise. Model
    parallelism is initialized with tensor/pipeline size 1. After the test
    the distributed environment is cleaned up and the rendezvous file is
    removed.
    """
    # mkstemp returns an already-open OS-level descriptor. Only the path is
    # needed, so close it right away instead of leaking one fd per test.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)

    backend = "nccl"
    if current_platform.is_cpu() or current_platform.is_tpu():
        backend = "gloo"

    init_distributed_environment(world_size=1,
                                 rank=0,
                                 distributed_init_method=f"file://{temp_file}",
                                 local_rank=0,
                                 backend=backend)
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory(shutdown_ray=True)
    # Don't leave rendezvous files behind in the temp directory. The file may
    # already be gone (e.g. removed by the store itself), which is fine.
    try:
        os.remove(temp_file)
    except FileNotFoundError:
        pass


@pytest.fixture
def dist_init_torch_only():
    """Initialize a bare ``torch.distributed`` process group of world size 1.

    Does nothing if a process group is already initialized. No vLLM
    model-parallel state is created, and this fixture performs no teardown
    itself.
    """
    if torch.distributed.is_initialized():
        return

    # Same backend choice as dist_init: fall back to gloo on CPU and TPU.
    backend = "nccl"
    if current_platform.is_cpu() or current_platform.is_tpu():
        backend = "gloo"

    # Only the path is needed, so close the descriptor mkstemp opens to avoid
    # leaking it. The file itself is kept, since the process group outlives
    # this fixture and may still use it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)

    torch.distributed.init_process_group(world_size=1,
                                         rank=0,
                                         init_method=f"file://{temp_file}",
                                         backend=backend)


class DummyLoRAModel(nn.Sequential, SupportsLoRA):
    """``nn.Sequential`` container that is also marked as ``SupportsLoRA``.

    The ``dummy_model*`` fixtures in this file use it to hold a small stack
    of vLLM parallel layers.
    """


@pytest.fixture
def dummy_model() -> nn.Module:
    """Build a small ``DummyLoRAModel`` from vLLM parallel linear layers.

    The model contains:
      * top-level ``dense1``/``dense2`` layers;
      * a nested ``layer1`` submodule;
      * an ``output`` column-parallel layer;
      * an ``lm_head`` / ``logits_processor`` / ``sampler`` tail.

    It also sets ``config`` (a ``MagicMock``), ``embedding_modules`` and
    ``unpadded_vocab_size``. These are presumably the attributes the LoRA
    code under test reads from a model; verify against the callers.
    """
    model = DummyLoRAModel(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", ColumnParallelLinear(50, 10)),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("logits_processor", LogitsProcessor(512)),
            ("sampler", Sampler())
        ]))
    model.config = MagicMock()
    model.embedding_modules = {"lm_head": "lm_head"}
    model.unpadded_vocab_size = 32000
    return model


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    """Build a ``DummyLoRAModel`` with a merged (packed) ``gate_up_proj`` layer.

    It is identical in layout to ``dummy_model``, except that the ``output``
    layer is replaced by a ``MergedColumnParallelLinear`` named
    ``gate_up_proj`` with two output shards of size 5. The model's
    ``packed_modules_mapping`` maps that layer to ``gate_proj`` and
    ``up_proj``.
    """
    model = DummyLoRAModel(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("logits_processor", LogitsProcessor(512)),
            ("sampler", Sampler())
        ]))
    model.config = MagicMock()
    # One packed layer composed of two logical sub-projections.
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model.embedding_modules = {"lm_head": "lm_head"}
    model.unpadded_vocab_size = 32000
    return model


@pytest.fixture(scope="session")
def llama_2_7b_base_huggingface_id():
    """Hugging Face repo id of the base model used with the SQL LoRA adapter."""
    # used as a base model for testing with sql lora adapter
    return "meta-llama/Llama-2-7b-hf"


@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    """Hugging Face repo id of the SQL LoRA adapter."""
    # huggingface repo id is used to test lora runtime downloading.
    return "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
    """Download the SQL LoRA adapter once per session; return its local path."""
    return snapshot_download(repo_id=sql_lora_huggingface_id)


@pytest.fixture(scope="session")
def mixtral_lora_files():
    """Download the Mixtral LoRA adapter once per session; return its path."""
    # Note: this module has incorrect adapter_config.json to test
    # https://github.com/vllm-project/vllm/pull/5909/files.
    return snapshot_download(repo_id="SangBinCho/mixtral-lora")


@pytest.fixture(scope="session")
def chatglm3_lora_files():
    """Download the ChatGLM3 text2sql LoRA adapter; return its local path."""
    return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")


@pytest.fixture(scope="session")
def baichuan_lora_files():
    """Download the Baichuan-7B text2sql LoRA adapter; return its local path."""
    return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")


@pytest.fixture(scope="session")
def baichuan_zero_lora_files():
    """Download the zero-initialized Baichuan-7B LoRA adapter; return its path."""
    # all the lora_B weights are initialized to zero.
    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
    """Download the Baichuan-7B zero/regex LoRA adapter; return its path."""
    return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")


@pytest.fixture(scope="session")
def ilama_lora_files():
    """Download the ilama text2sql LoRA adapter; return its local path."""
    return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")


@pytest.fixture(scope="session")
def minicpmv_lora_files():
    """Download the MiniCPM-V 2.5 LoRA adapter; return its local path."""
    return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")


@pytest.fixture(scope="session")
def qwen2vl_lora_files():
    """Download the Qwen2-VL LoRA adapter; return its local path."""
    return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")


@pytest.fixture(scope="session")
def qwen25vl_base_huggingface_id():
    """Hugging Face repo id of the base model for the Qwen2.5-VL LoRA adapter."""
    # used as a base model for testing with qwen25vl lora adapter
    return "Qwen/Qwen2.5-VL-3B-Instruct"


@pytest.fixture(scope="session")
def qwen25vl_lora_files():
    """Download the Qwen2.5-VL LoRA adapter; return its local path."""
    return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
    """Download the TinyLlama colorist LoRA adapter; return its local path."""
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def phi2_lora_files():
    """Download the Phi-2 SQL LoRA adapter; return its local path."""
    return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")


@pytest.fixture
def reset_default_device():
    """Restore torch's default device after the test finishes.

    Some tests, such as `test_punica_ops.py`, explicitly set the default
    device, which can affect subsequent tests. Adding this fixture helps
    avoid this problem.
    """
    original_device = torch.get_default_device()
    yield
    torch.set_default_device(original_device)