# conftest.py: shared pytest fixtures for these tests.
import os
import tempfile
from collections import OrderedDict
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch

import pytest
import safetensors
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.distributed import (cleanup_dist_env_and_memory,
                              init_distributed_environment,
                              initialize_model_parallel)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model

class ContextIDInfo(TypedDict):
    """Static description of one long-context LoRA checkpoint, keyed by id."""
    lora_id: int
    context_length: str


class ContextInfo(TypedDict):
    """Resolved long-context LoRA: local checkpoint path and context length."""
    lora: str
    context_length: str


# (lora_id, context_length) pairs of the long-context LoRA checkpoints; the
# ``long_context_infos`` fixture resolves each id to a downloaded checkpoint.
LONG_LORA_INFOS: List[ContextIDInfo] = [
    {"lora_id": lora_id, "context_length": context_length}
    for lora_id, context_length in ((1, "16k"), (2, "16k"), (3, "32k"))
]

@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell ``cleanup_fixture`` whether to tear down global state after a test.

    Returns False for tests carrying the ``skip_global_cleanup`` marker and
    True otherwise. Subdirectories can also override this fixture to skip
    global cleanup; this can provide a ~10x speedup for non-GPU unit tests
    since they don't need to initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return skip_marker is None


@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run around every test; afterwards release distributed env and memory.

    The teardown is skipped when ``should_do_global_cleanup_after_test`` is
    False.
    """
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
def dist_init():
    """Set up a single-rank vLLM distributed environment for one test.

    Rendezvous uses a fresh temporary file (``file://`` init method) and the
    NCCL backend. Model-parallel groups are then initialized with
    ``initialize_model_parallel(1, 1)``. Everything is cleaned up with
    ``cleanup_dist_env_and_memory`` after the test.
    """
    # mkstemp() returns an already-open OS file descriptor along with the
    # path. Only the path is needed for the rendezvous, so close the
    # descriptor right away instead of leaking one per test.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
def dist_init_torch_only():
    """Initialize a bare single-rank ``torch.distributed`` NCCL process group.

    Does nothing if a process group is already initialized. This fixture
    performs no teardown of its own.
    """
    if torch.distributed.is_initialized():
        return
    # Only the path is needed for the file:// rendezvous; close the descriptor
    # that mkstemp() opened so it is not leaked.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=1,
        rank=0,
        init_method=f"file://{temp_file}",
    )


@pytest.fixture
def dummy_model() -> nn.Module:
    """Build a small ``nn.Sequential`` of vLLM parallel layers for LoRA tests.

    The model ends with ``lm_head``, ``logits_processor`` and ``sampler``
    entries, and gets a ``MagicMock`` as its ``config`` attribute.
    """
    # Sub-modules are created in the same order as they appear in the model.
    layers = OrderedDict()
    layers["dense1"] = ColumnParallelLinear(764, 100)
    layers["dense2"] = RowParallelLinear(100, 50)
    layers["layer1"] = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(100, 10)),
            ("dense2", RowParallelLinear(10, 50)),
        ]))
    layers["act2"] = nn.ReLU()
    layers["output"] = ColumnParallelLinear(50, 10)
    layers["outact"] = nn.Sigmoid()
    # Special handling for lm_head & sampler
    layers["lm_head"] = ParallelLMHead(512, 10)
    layers["logits_processor"] = LogitsProcessor(512)
    layers["sampler"] = Sampler()

    model = nn.Sequential(layers)
    model.config = MagicMock()
    return model


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    """Like ``dummy_model`` but with a merged ``gate_up_proj`` layer.

    In place of the ``output`` layer this model uses a
    ``MergedColumnParallelLinear(50, [5, 5])`` named ``gate_up_proj``.
    """
    # Sub-modules are created in the same order as they appear in the model.
    layers = OrderedDict()
    layers["dense1"] = ColumnParallelLinear(764, 100)
    layers["dense2"] = RowParallelLinear(100, 50)
    layers["layer1"] = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(100, 10)),
            ("dense2", RowParallelLinear(10, 50)),
        ]))
    layers["act2"] = nn.ReLU()
    layers["gate_up_proj"] = MergedColumnParallelLinear(50, [5, 5])
    layers["outact"] = nn.Sigmoid()
    # Special handling for lm_head & sampler
    layers["lm_head"] = ParallelLMHead(512, 10)
    layers["logits_processor"] = LogitsProcessor(512)
    layers["sampler"] = Sampler()

    model = nn.Sequential(layers)
    model.config = MagicMock()
    return model


@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    """Return the Hugging Face repo id of the SQL test LoRA.

    The repo id (not a local path) is used to test downloading LoRA adapters
    at runtime.
    """
    repo_id = "yard1/llama-2-7b-sql-lora-test"
    return repo_id


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
    """Download the SQL test LoRA once per session; return its local path."""
    local_path = snapshot_download(repo_id=sql_lora_huggingface_id)
    return local_path


@pytest.fixture(scope="session")
def lora_bias_files():
    """Download the ``followumesh/granite-3b-lora8-bias`` adapter; return its path."""
    repo_id = "followumesh/granite-3b-lora8-bias"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def mixtral_lora_files():
    """Download the ``SangBinCho/mixtral-lora`` adapter; return its path.

    Note: this adapter has an incorrect ``adapter_config.json``, which is used
    to test https://github.com/vllm-project/vllm/pull/5909/files.
    """
    repo_id = "SangBinCho/mixtral-lora"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def mixtral_lora_files_all_target_modules():
    """Download the ``dyang415/mixtral-lora-v0`` adapter; return its path."""
    repo_id = "dyang415/mixtral-lora-v0"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def jamba_lora_files():
    """Download the Jamba LoRA adapter and strip its non-LoRA tensors.

    Some of the adapters contain weights that are unnecessary for serving.
    Every tensor whose name does not contain ``"lora"`` is removed, and
    ``adapter_model.safetensors`` is rewritten in place.

    Returns:
        The local directory of the downloaded adapter.
    """
    # Import the submodule explicitly: a plain ``import safetensors`` does not
    # by itself load ``safetensors.torch``.
    from safetensors.torch import load_file, save_file

    def remove_unnecessary_weights(path):
        # Build the file path from the argument (not a closed-over variable)
        # so the helper works for whichever adapter directory it is given.
        lora_path = f"{path}/adapter_model.safetensors"
        tensors = load_file(lora_path)
        lora_tensors = {
            name: tensor
            for name, tensor in tensors.items() if "lora" in name
        }
        save_file(lora_tensors, lora_path)

    adapter_path = snapshot_download(
        repo_id=
        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")

    remove_unnecessary_weights(adapter_path)
    return adapter_path


@pytest.fixture(scope="session")
def gemma_lora_files():
    """Download the ``wskwon/gemma-7b-test-lora`` adapter; return its path."""
    repo_id = "wskwon/gemma-7b-test-lora"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def chatglm3_lora_files():
    """Download the ``jeeejeee/chatglm3-text2sql-spider`` adapter; return its path."""
    repo_id = "jeeejeee/chatglm3-text2sql-spider"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def baichuan_lora_files():
    """Download the ``jeeejeee/baichuan7b-text2sql-spider`` adapter; return its path."""
    repo_id = "jeeejeee/baichuan7b-text2sql-spider"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def baichuan_zero_lora_files():
    """Download the ``jeeejeee/baichuan7b-zero-init`` adapter; return its path.

    All the lora_B weights of this adapter are initialized to zero.
    """
    repo_id = "jeeejeee/baichuan7b-zero-init"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
    """Download the ``jeeejeee/baichuan-7b-lora-zero-regex`` adapter; return its path."""
    repo_id = "jeeejeee/baichuan-7b-lora-zero-regex"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def minicpmv_lora_files():
    """Download the ``jeeejeee/minicpmv25-lora-pokemon`` adapter; return its path."""
    repo_id = "jeeejeee/minicpmv25-lora-pokemon"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def qwen2vl_lora_files():
    """Download the ``jeeejeee/qwen2-vl-lora-pokemon`` adapter; return its path."""
    repo_id = "jeeejeee/qwen2-vl-lora-pokemon"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def tinyllama_lora_files():
    """Download the ``jashing/tinyllama-colorist-lora`` adapter; return its path."""
    repo_id = "jashing/tinyllama-colorist-lora"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def phi2_lora_files():
    """Download the ``isotr0py/phi-2-test-sql-lora`` adapter; return its path."""
    repo_id = "isotr0py/phi-2-test-sql-lora"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
    """Download the first 16k long-context test LoRA; return its path."""
    repo_id = "SangBinCho/long_context_16k_testing_1"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def long_context_lora_files_16k_2():
    """Download the second 16k long-context test LoRA; return its path."""
    repo_id = "SangBinCho/long_context_16k_testing_2"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def long_context_lora_files_32k():
    """Download the 32k long-context test LoRA; return its path."""
    repo_id = "SangBinCho/long_context_32k_testing"
    return snapshot_download(repo_id=repo_id)


@pytest.fixture(scope="session")
def long_context_infos(long_context_lora_files_16k_1,
                       long_context_lora_files_16k_2,
                       long_context_lora_files_32k):
    """Resolve every entry of ``LONG_LORA_INFOS`` to its downloaded checkpoint.

    Distributed state is cleaned up first.

    Returns:
        Dict mapping lora id to a ``ContextInfo`` with the checkpoint's
        ``"context_length"`` and its local path under ``"lora"``.

    Raises:
        AssertionError: if ``LONG_LORA_INFOS`` holds an id other than 1, 2, 3.
    """
    cleanup_dist_env_and_memory(shutdown_ray=True)
    path_for_id = {
        1: long_context_lora_files_16k_1,
        2: long_context_lora_files_16k_2,
        3: long_context_lora_files_32k,
    }
    infos: Dict[int, ContextInfo] = {}
    for entry in LONG_LORA_INFOS:
        lora_id = entry["lora_id"]
        if lora_id not in path_for_id:
            raise AssertionError("Unknown lora id")
        infos[lora_id] = {
            "context_length": entry["context_length"],
            "lora": path_for_id[lora_id],
        }
    return infos


@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
    """Yield the ``llm_engine`` of a Llama-2-7B ``vllm.LLM`` with a LoRA config.

    The LLM is constructed with ``enable_lora=False``. While it is being
    built, ``vllm.worker.model_runner.get_model`` is patched so that the
    model is loaded with ``LoRAConfig(max_loras=4, max_lora_rank=8)``.
    Distributed state is cleaned up before setup and after the test.
    """
    cleanup_dist_env_and_memory(shutdown_ray=True)
    original_get_model = get_model

    def get_model_with_lora_config(**kwargs):
        # Attach a LoRA config to the vLLM config, then load as usual.
        kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
                                                       max_lora_rank=8)
        return original_get_model(**kwargs)

    with patch("vllm.worker.model_runner.get_model",
               get_model_with_lora_config):
        llm = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
    yield llm.llm_engine
    del llm
    cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
    """Yield the model held by the engine's driver-worker model runner."""
    runner = (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
              model_runner)
    yield runner.model