# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
from collections import OrderedDict
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch

import pytest
import safetensors
import safetensors.torch
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.distributed import (cleanup_dist_env_and_memory,
                              init_distributed_environment,
                              initialize_model_parallel)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.platforms import current_platform

from ..utils import models_path_prefix


class ContextIDInfo(TypedDict):
    """Static description of one long-context LoRA checkpoint."""

    # Numeric id; ``long_context_infos`` maps it to an adapter fixture.
    lora_id: int
    # Context-window label such as "16k" or "32k".
    context_length: str


class ContextInfo(TypedDict):
    """A long-context LoRA checkpoint resolved to its adapter location."""

    # Adapter location as returned by a ``long_context_lora_files_*`` fixture.
    lora: str
    # Context-window label copied from the matching ``ContextIDInfo``.
    context_length: str


# Long-context LoRA checkpoints used by the tests, keyed by ``lora_id``.
LONG_LORA_INFOS: List[ContextIDInfo] = [{
    "lora_id": 1,
    "context_length": "16k",
}, {
    "lora_id": 2,
    "context_length": "16k",
}, {
    "lora_id": 3,
    "context_length": "32k",
}]


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.

    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.

    Returns:
        ``False`` when the test carries the ``skip_global_cleanup`` marker,
        ``True`` otherwise.
    """
    return not request.node.get_closest_marker("skip_global_cleanup")


@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After every test, tear down the distributed env and free memory.

    Teardown (which also shuts down Ray) is skipped when
    ``should_do_global_cleanup_after_test`` is ``False``.
    """
    yield
    if should_do_global_cleanup_after_test:
        cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
def dist_init():
    """Set up a single-rank vLLM distributed environment for one test.

    Uses the NCCL backend, or Gloo on CPU platforms, with a ``file://``
    rendezvous on a fresh temporary file. It then calls
    ``initialize_model_parallel(1, 1)``. After the test it tears the
    environment down, including Ray.
    """
    # mkstemp returns an already-open descriptor. Only the path is needed for
    # the file:// init method, so close the descriptor instead of leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)

    backend = "nccl"
    if current_platform.is_cpu():
        backend = "gloo"

    init_distributed_environment(world_size=1,
                                 rank=0,
                                 distributed_init_method=f"file://{temp_file}",
                                 local_rank=0,
                                 backend=backend)
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
def dist_init_torch_only():
    """Create a bare single-rank ``torch.distributed`` process group.

    Does nothing if a process group is already initialized. Uses NCCL, or
    Gloo on CPU platforms, with a ``file://`` rendezvous on a fresh
    temporary file. This fixture does not tear the group down itself.
    """
    if torch.distributed.is_initialized():
        return

    backend = "nccl"
    if current_platform.is_cpu():
        backend = "gloo"

    # Only the path is needed for the file:// init method; close the
    # descriptor that mkstemp opened so it is not leaked.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    torch.distributed.init_process_group(world_size=1,
                                         rank=0,
                                         init_method=f"file://{temp_file}",
                                         backend=backend)


@pytest.fixture
def dummy_model() -> nn.Module:
    """Build a small ``nn.Sequential`` of vLLM parallel layers for LoRA tests.

    The model contains:
    - column/row parallel linears, one pair nested under ``layer1``;
    - ReLU/Sigmoid activations;
    - the ``lm_head``, ``logits_processor`` and ``sampler`` modules, which
      need special handling.

    ``model.config`` is set to a ``MagicMock`` stub.
    """
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", ColumnParallelLinear(50, 10)),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("logits_processor", LogitsProcessor(512)),
            ("sampler", Sampler())
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    """Like ``dummy_model`` but with a merged ``gate_up_proj`` layer.

    The plain ``output`` linear is replaced by a
    ``MergedColumnParallelLinear(50, [5, 5])`` named ``gate_up_proj``.
    ``model.config`` is set to a ``MagicMock`` stub.
    """
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("logits_processor", LogitsProcessor(512)),
            ("sampler", Sampler())
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    """Location of the ``yard1/llama-2-7b-sql-lora-test`` LoRA adapter.

    Upstream, this fixture returns a Hugging Face repo id that is used to
    test LoRA runtime downloading. Here the id is joined onto
    ``models_path_prefix`` instead.
    """
    return os.path.join(models_path_prefix, "yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
    """Return a local directory containing the SQL LoRA adapter.

    ``sql_lora_huggingface_id`` is joined onto ``models_path_prefix``, so it
    may already be a filesystem path. ``snapshot_download`` validates
    ``repo_id`` as ``name`` or ``namespace/name`` and would reject a
    multi-segment path. An existing directory is therefore returned as-is,
    and only a genuine repo id is downloaded.
    """
    if os.path.isdir(sql_lora_huggingface_id):
        return sql_lora_huggingface_id
    return snapshot_download(repo_id=sql_lora_huggingface_id)


@pytest.fixture(scope="session")
def lora_bias_files():
    """Download ``followumesh/granite-3b-lora8-bias`` and return its path.

    The returned value is the local snapshot directory produced by
    ``snapshot_download``.
    """
    return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")


@pytest.fixture(scope="session")
def mixtral_lora_files():
    """Path to the ``SangBinCho/mixtral-lora`` adapter.

    The path is resolved under ``models_path_prefix`` instead of downloading
    the Hugging Face repo of the same name. Note: this module has an
    incorrect ``adapter_config.json``, to test
    https://github.com/vllm-project/vllm/pull/5909/files.
    """
    return os.path.join(models_path_prefix, "SangBinCho/mixtral-lora")


@pytest.fixture(scope="session")
def mixtral_lora_files_all_target_modules():
    """Local snapshot path of the ``dyang415/mixtral-lora-v0`` adapter."""
    hub_repo = "dyang415/mixtral-lora-v0"
    return snapshot_download(repo_id=hub_repo)


@pytest.fixture(scope="session")
def jamba_lora_files():
    """Download the Jamba LoRA adapter and strip its non-LoRA tensors.

    Some of the adapters carry weights that are unnecessary for serving.
    Every tensor whose key does not contain ``"lora"`` is dropped, and
    ``adapter_model.safetensors`` is rewritten in place.

    Returns:
        The local snapshot directory of the adapter.
    """

    def remove_unnecessary_weights(path):
        # Operate on the given ``path`` rather than the enclosing
        # ``adapter_path``, so the helper does what its signature says.
        lora_path = os.path.join(path, "adapter_model.safetensors")
        tensors = safetensors.torch.load_file(lora_path)
        lora_only = {k: v for k, v in tensors.items() if "lora" in k}
        safetensors.torch.save_file(lora_only, lora_path)

    adapter_path = snapshot_download(
        repo_id=
        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")

    remove_unnecessary_weights(adapter_path)
    return adapter_path


@pytest.fixture(scope="session")
def gemma_lora_files():
    """Path to the ``wskwon/gemma-7b-test-lora`` adapter.

    The path is resolved under ``models_path_prefix`` instead of being
    downloaded from the Hugging Face Hub.
    """
    return os.path.join(models_path_prefix, "wskwon/gemma-7b-test-lora")


@pytest.fixture(scope="session")
def chatglm3_lora_files():
    """Path to the ``jeeejeee/chatglm3-text2sql-spider`` adapter.

    The path is resolved under ``models_path_prefix`` instead of being
    downloaded from the Hugging Face Hub.
    """
    return os.path.join(models_path_prefix,
                        "jeeejeee/chatglm3-text2sql-spider")


@pytest.fixture(scope="session")
def baichuan_lora_files():
    """Path to the ``jeeejeee/baichuan7b-text2sql-spider`` adapter.

    The path is resolved under ``models_path_prefix`` instead of being
    downloaded from the Hugging Face Hub.
    """
    return os.path.join(models_path_prefix,
                        "jeeejeee/baichuan7b-text2sql-spider")


@pytest.fixture(scope="session")
def baichuan_zero_lora_files():
    """Path to the ``jeeejeee/baichuan7b-zero-init`` adapter.

    In this adapter, all the lora_B weights are initialized to zero. The
    path is resolved under ``models_path_prefix`` instead of being
    downloaded from the Hugging Face Hub.
    """
    return os.path.join(models_path_prefix, "jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
    """Local snapshot path of the Baichuan zero-regex LoRA adapter.

    Source repo: ``jeeejeee/baichuan-7b-lora-zero-regex``.
    """
    hub_repo = "jeeejeee/baichuan-7b-lora-zero-regex"
    return snapshot_download(repo_id=hub_repo)


@pytest.fixture(scope="session")
def minicpmv_lora_files():
    """Local snapshot path of the ``jeeejeee/minicpmv25-lora-pokemon`` adapter."""
    hub_repo = "jeeejeee/minicpmv25-lora-pokemon"
    return snapshot_download(repo_id=hub_repo)


@pytest.fixture(scope="session")
def qwen2vl_lora_files():
    """Local snapshot path of the ``jeeejeee/qwen2-vl-lora-pokemon`` adapter."""
    hub_repo = "jeeejeee/qwen2-vl-lora-pokemon"
    return snapshot_download(repo_id=hub_repo)


@pytest.fixture(scope="session")
def tinyllama_lora_files():
    """Path to the ``jashing/tinyllama-colorist-lora`` adapter.

    The path is resolved under ``models_path_prefix`` instead of being
    downloaded from the Hugging Face Hub.
    """
    return os.path.join(models_path_prefix, "jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def phi2_lora_files():
    """Path to the ``isotr0py/phi-2-test-sql-lora`` adapter.

    The path is resolved under ``models_path_prefix`` instead of being
    downloaded from the Hugging Face Hub.
    """
    return os.path.join(models_path_prefix, "isotr0py/phi-2-test-sql-lora")


@pytest.fixture(scope="session")
def qwen_lora_files():
    """Path to the ``customize/qwen-nl2dsl-lora`` adapter.

    The path is resolved under ``models_path_prefix``; nothing is downloaded.
    """
    relative_dir = "customize/qwen-nl2dsl-lora"
    return os.path.join(models_path_prefix, relative_dir)


@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
    """Path to the ``SangBinCho/long_context_16k_testing_1`` adapter.

    The path is resolved under ``models_path_prefix`` instead of being
    downloaded from the Hugging Face Hub.
    """
    return os.path.join(models_path_prefix,
                        "SangBinCho/long_context_16k_testing_1")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_2():
    """Path to the ``SangBinCho/long_context_16k_testing_2`` adapter.

    The path is resolved under ``models_path_prefix`` instead of being
    downloaded from the Hugging Face Hub.
    """
    return os.path.join(models_path_prefix,
                        "SangBinCho/long_context_16k_testing_2")


@pytest.fixture(scope="session")
def long_context_lora_files_32k():
    """Path to the ``SangBinCho/long_context_32k_testing`` adapter.

    The path is resolved under ``models_path_prefix`` instead of being
    downloaded from the Hugging Face Hub.
    """
    return os.path.join(models_path_prefix,
                        "SangBinCho/long_context_32k_testing")


@pytest.fixture(scope="session")
def long_context_infos(long_context_lora_files_16k_1,
                       long_context_lora_files_16k_2,
                       long_context_lora_files_32k):
    """Resolve every entry of ``LONG_LORA_INFOS`` to its adapter location.

    Returns:
        Dict mapping ``lora_id`` to a ``ContextInfo``. Its ``"lora"`` is the
        matching ``long_context_lora_files_*`` fixture value, and its
        ``"context_length"`` is copied from ``LONG_LORA_INFOS``.

    Raises:
        AssertionError: if ``LONG_LORA_INFOS`` contains an id that has no
            adapter fixture.
    """
    # NOTE(review): presumably clears state left by earlier tests before the
    # long-context tests run -- confirm.
    cleanup_dist_env_and_memory(shutdown_ray=True)

    # lora_id -> adapter location.
    adapters_by_id = {
        1: long_context_lora_files_16k_1,
        2: long_context_lora_files_16k_2,
        3: long_context_lora_files_32k,
    }
    infos: Dict[int, ContextInfo] = {}
    for lora_checkpoint_info in LONG_LORA_INFOS:
        lora_id = lora_checkpoint_info["lora_id"]
        if lora_id not in adapters_by_id:
            raise AssertionError("Unknown lora id")
        infos[lora_id] = {
            "context_length": lora_checkpoint_info["context_length"],
            "lora": adapters_by_id[lora_id],
        }
    return infos


@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
    """Yield the ``llm_engine`` of a Llama-2-7B ``vllm.LLM`` with a LoRA config.

    The LLM is created with ``enable_lora=False``. While it loads,
    ``vllm.worker.model_runner.get_model`` is patched to set
    ``vllm_config.lora_config = LoRAConfig(max_loras=4, max_lora_rank=8)``
    before delegating to the real ``get_model``. The model is therefore
    loaded with that LoRA config. Distributed state and memory are cleaned
    up before creation and after the test.
    """
    cleanup_dist_env_and_memory(shutdown_ray=True)
    # Keep a handle to the unpatched loader so the wrapper can delegate to it.
    get_model_old = get_model

    def get_model_patched(**kwargs):
        kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
                                                       max_lora_rank=8)
        return get_model_old(**kwargs)

    with patch("vllm.worker.model_runner.get_model", get_model_patched):
        engine = vllm.LLM(os.path.join(models_path_prefix,
                                       "meta-llama/Llama-2-7b-hf"),
                          enable_lora=False)

    yield engine.llm_engine
    del engine
    cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
    """Yield the model held by the engine fixture's driver-worker runner."""
    yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
           model_runner.model)