# test_worker.py
import os
import random
import tempfile
from unittest.mock import patch

6
7
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig)
8
9
10
11
12
13
14
15
16
17
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker


@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
    """Check that a ``Worker`` lists the LoRA adapters it was told to activate.

    The worker is built for a single CUDA device with
    ``load_format="dummy"``. The test then:

    1. activates no adapters and expects ``list_loras()`` to be empty;
    2. activates ``n_loras`` distinct requests and expects exactly their ids;
    3. for 32 seeded rounds, activates a random selection of those requests
       and expects ``list_loras()`` to be a superset of the selected ids.

    Args:
        sql_lora_files: Passed as the third argument of every
            ``LoRARequest``; supplied by a pytest fixture defined outside
            this file.
    """
    # mkstemp() returns an open OS-level file descriptor together with the
    # path. Only the path is used below, so close the descriptor right away
    # instead of leaking it for the lifetime of the test process.
    init_fd, init_path = tempfile.mkstemp()
    os.close(init_fd)

    worker = Worker(
        model_config=ModelConfig(
            "meta-llama/Llama-2-7b-hf",
            task="auto",
            tokenizer="meta-llama/Llama-2-7b-hf",
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
            revision=None,
        ),
        load_config=LoadConfig(
            download_dir=None,
            load_format="dummy",
        ),
        parallel_config=ParallelConfig(1, 1, False),
        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
        device_config=DeviceConfig("cuda"),
        cache_config=CacheConfig(block_size=16,
                                 gpu_memory_utilization=1.,
                                 swap_space=0,
                                 cache_dtype="auto"),
        local_rank=0,
        rank=0,
        lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
                               max_loras=32),
        distributed_init_method=f"file://{init_path}",
    )
    worker.init_device()
    worker.load_model()

    # With nothing activated, no adapter should be listed.
    worker.model_runner.set_active_loras([], LoRAMapping([], []))
    assert worker.list_loras() == set()

    n_loras = 32
    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
    ]

    # Activating every request at once must list exactly those ids.
    worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
    assert worker.list_loras() == {
        lora_request.lora_int_id
        for lora_request in lora_requests
    }

    for seed in range(32):
        # Re-seed each round so the selection is reproducible.
        random.seed(seed)
        iter_lora_requests = random.choices(lora_requests,
                                            k=random.randint(1, n_loras))
        random.shuffle(iter_lora_requests)
        # Drop a random number of requests from the tail. A plain
        # ``[:-n_dropped]`` slice would empty the list when n_dropped == 0
        # (``[:-0]`` is ``[:0]``), so compute the kept length explicitly and
        # clamp it at zero for the case n_dropped > len(iter_lora_requests).
        n_dropped = random.randint(0, n_loras)
        n_kept = max(0, len(iter_lora_requests) - n_dropped)
        iter_lora_requests = iter_lora_requests[:n_kept]
        worker.model_runner.set_active_loras(iter_lora_requests,
                                             LoRAMapping([], []))
        # Only a superset is required: ids activated in earlier rounds may
        # still be listed.
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})