# SPDX-License-Identifier: Apache-2.0

import os
import random
import tempfile
6
from typing import Union
7
8
from unittest.mock import patch

9
import vllm.envs as envs
10
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
11
12
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
13
14
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
15
from vllm.v1.worker.gpu_worker import Worker as V1Worker
16
17
18
19
20
from vllm.worker.worker import Worker


@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
    """Check that a worker reports the LoRA adapters activated on it.

    A single worker (V1 when ``envs.VLLM_USE_V1`` is set, otherwise V0) is
    built with ``load_format="dummy"`` and room for 32 LoRAs. The test then:

    1. activates an empty request list and expects ``list_loras()`` to be
       empty;
    2. activates 32 requests that all point at ``sql_lora_files`` and expects
       exactly their ids to be listed;
    3. repeatedly activates a seeded random selection of those requests and
       expects every selected id to be listed afterwards.

    Args:
        sql_lora_files: Path of the LoRA adapter passed to every
            ``LoRARequest`` (supplied by a pytest fixture defined outside
            this file).
    """

    def set_active_loras(worker: Union[Worker, V1Worker],
                         lora_requests: list[LoRARequest]):
        """Activate ``lora_requests`` on ``worker`` with an empty mapping."""
        lora_mapping = LoRAMapping([], [])
        if isinstance(worker, Worker):
            # v0 case
            worker.model_runner.set_active_loras(lora_requests, lora_mapping)
        else:
            # v1 case
            worker.model_runner.lora_manager.set_active_adapters(
                lora_requests, lora_mapping)

    worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker

    vllm_config = VllmConfig(
        model_config=ModelConfig(
            "meta-llama/Llama-2-7b-hf",
            task="auto",
            tokenizer="meta-llama/Llama-2-7b-hf",
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
            revision=None,
            enforce_eager=True,
        ),
        load_config=LoadConfig(
            download_dir=None,
            load_format="dummy",
        ),
        parallel_config=ParallelConfig(
            pipeline_parallel_size=1,
            tensor_parallel_size=1,
            data_parallel_size=1,
        ),
        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
        device_config=DeviceConfig("cuda"),
        cache_config=CacheConfig(
            block_size=16,
            gpu_memory_utilization=1.0,
            swap_space=0,
            cache_dtype="auto",
        ),
        lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
                               max_loras=32),
    )

    # mkstemp() hands back an *open* descriptor together with the path; only
    # the path is needed here, so close the descriptor instead of leaking it.
    init_fd, init_path = tempfile.mkstemp()
    os.close(init_fd)

    worker = worker_cls(
        vllm_config=vllm_config,
        local_rank=0,
        rank=0,
        distributed_init_method=f"file://{init_path}",
    )

    worker.init_device()
    worker.load_model()

    set_active_loras(worker, [])
    assert worker.list_loras() == set()

    n_loras = 32
    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
    ]

    set_active_loras(worker, lora_requests)
    assert worker.list_loras() == {
        lora_request.lora_int_id
        for lora_request in lora_requests
    }

    for i in range(32):
        # Re-seed per iteration so each selection is reproducible.
        random.seed(i)
        iter_lora_requests = random.choices(lora_requests,
                                            k=random.randint(1, n_loras))
        random.shuffle(iter_lora_requests)
        # Drop a random-length tail. A draw of 0 must keep the whole list:
        # slicing with ``[:-0]`` is ``[:0]`` and would discard everything.
        n_dropped = random.randint(0, n_loras)
        if n_dropped:
            iter_lora_requests = iter_lora_requests[:-n_dropped]
        # Activate the random selection itself (not the full request list),
        # otherwise the assertion below checks nothing new.
        set_active_loras(worker, iter_lora_requests)
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})