test_worker.py 3.34 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import os
import random
import tempfile
7
from typing import Union
8
9
from unittest.mock import patch

10
import vllm.envs as envs
11
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
12
13
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
14
15
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
16
from vllm.v1.worker.gpu_worker import Worker as V1Worker
17
from vllm.worker.worker import Worker
18
from ..utils import models_path_prefix
19

Jee Jee Li's avatar
Jee Jee Li committed
20
# Number of LoRA requests the test creates; also used as both `max_loras`
# and `max_cpu_loras` in the LoRAConfig and as the number of randomized
# activation rounds.
NUM_LORAS = 16
21
22


23
24
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
    """Check that a worker reports the LoRA adapters activated on it.

    A single-rank worker (``V1Worker`` when ``envs.VLLM_USE_V1`` is set,
    otherwise the v0 ``Worker``) is built with ``load_format="dummy"`` and a
    LoRA config sized for ``NUM_LORAS`` adapters. The test then verifies
    ``worker.list_loras()`` after activating:

    1. no adapters,
    2. all ``NUM_LORAS`` adapters,
    3. a sequence of seeded pseudo-random selections of those adapters.

    Args:
        sql_lora_files: Passed as the third positional argument of every
            ``LoRARequest`` (presumably the adapter location provided by a
            fixture -- confirm against the conftest).
    """

    def set_active_loras(worker: Union[Worker, V1Worker],
                         lora_requests: list[LoRARequest]):
        """Activate ``lora_requests`` on ``worker`` using an empty mapping."""
        lora_mapping = LoRAMapping([], [])
        if isinstance(worker, Worker):
            # v0 case
            worker.model_runner.set_active_loras(lora_requests, lora_mapping)
        else:
            # v1 case
            worker.model_runner.lora_manager.set_active_adapters(
                lora_requests, lora_mapping)

    worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker

    # Model and tokenizer are loaded from the same local directory.
    model_path = os.path.join(models_path_prefix, "meta-llama/Llama-2-7b-hf")

    vllm_config = VllmConfig(
        model_config=ModelConfig(
            model_path,
            task="auto",
            tokenizer=model_path,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
            revision=None,
            enforce_eager=True,
        ),
        load_config=LoadConfig(
            download_dir=None,
            load_format="dummy",
        ),
        parallel_config=ParallelConfig(
            pipeline_parallel_size=1,
            tensor_parallel_size=1,
            data_parallel_size=1,
        ),
        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
        device_config=DeviceConfig("cuda"),
        cache_config=CacheConfig(
            block_size=16,
            swap_space=0,
            cache_dtype="auto",
        ),
        lora_config=LoRAConfig(max_lora_rank=8,
                               max_cpu_loras=NUM_LORAS,
                               max_loras=NUM_LORAS),
    )

    # mkstemp() returns an already-open descriptor together with the path.
    # Only the path is needed for the init method, so close the descriptor
    # right away instead of leaking it for the rest of the test session.
    init_fd, init_path = tempfile.mkstemp()
    os.close(init_fd)

    worker = worker_cls(
        vllm_config=vllm_config,
        local_rank=0,
        rank=0,
        distributed_init_method=f"file://{init_path}",
    )

    worker.init_device()
    worker.load_model()

    # Activating nothing must leave the worker without any adapters.
    set_active_loras(worker, [])
    assert worker.list_loras() == set()

    # LoRA ids are 1-based: names "1".."NUM_LORAS" with matching int ids.
    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files)
        for i in range(NUM_LORAS)
    ]

    set_active_loras(worker, lora_requests)
    assert worker.list_loras() == {
        lora_request.lora_int_id
        for lora_request in lora_requests
    }

    for i in range(NUM_LORAS):
        # Re-seed per round so every round's selection is reproducible.
        random.seed(i)
        iter_lora_requests = random.choices(lora_requests,
                                            k=random.randint(1, NUM_LORAS))
        random.shuffle(iter_lora_requests)
        # NOTE: a draw of 0 produces `[:-0]`, i.e. an empty list rather than
        # the untrimmed list; kept as-is so the selections stay unchanged.
        iter_lora_requests = iter_lora_requests[:-random.randint(0, NUM_LORAS)]
        # Activate the randomized selection itself. Passing the full
        # `lora_requests` here would make the assertion below vacuous,
        # because every adapter is already active from the step above.
        set_active_loras(worker, iter_lora_requests)
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})