"docs/source/deployment/frameworks/triton.md" did not exist on "d93d2d74fd807a091add17c2065ee8869339f76a"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import random
import tempfile
from unittest.mock import patch

9
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
10
11
                         ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.config.load import LoadConfig
12
from vllm.config.lora import LoRAConfig
13
14
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
15
from vllm.v1.worker.gpu_worker import Worker
16

Jee Jee Li's avatar
Jee Jee Li committed
17
18
# Number of distinct LoRA requests the test below creates; the same value is
# passed as both `max_loras` and `max_cpu_loras` in the test's LoRAConfig.
NUM_LORAS = 16

19
20
21

@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
    """Check that a worker reports the LoRA adapters activated on it.

    The test builds a ``Worker`` on the ``"cuda"`` device with
    ``load_format="dummy"``, then:

    1. activates no adapters and expects ``list_loras()`` to be empty;
    2. activates ``NUM_LORAS`` adapters and expects exactly their ids;
    3. repeatedly activates a seeded random subset of those adapters and
       expects every id in the subset to be listed.

    Args:
        sql_lora_files: Value passed as the third positional argument of
            every ``LoRARequest`` (supplied by a pytest fixture defined
            outside this file).
    """

    def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
        """Activate ``lora_requests`` on the worker with an empty mapping."""
        lora_mapping = LoRAMapping([], [])
        worker.model_runner.lora_manager.set_active_adapters(
            lora_requests, lora_mapping)

    vllm_config = VllmConfig(
        model_config=ModelConfig(
            "meta-llama/Llama-2-7b-hf",
            seed=0,
            dtype="float16",
            enforce_eager=True,
        ),
        load_config=LoadConfig(
            download_dir=None,
            load_format="dummy",
        ),
        parallel_config=ParallelConfig(
            pipeline_parallel_size=1,
            tensor_parallel_size=1,
            data_parallel_size=1,
        ),
        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
        device_config=DeviceConfig("cuda"),
        cache_config=CacheConfig(
            block_size=16,
            swap_space=0,
            cache_dtype="auto",
        ),
        lora_config=LoRAConfig(max_lora_rank=8,
                               max_cpu_loras=NUM_LORAS,
                               max_loras=NUM_LORAS),
    )

    # mkstemp() returns an *open* OS-level descriptor together with the path.
    # Only the path is needed for the file:// init method, so close the
    # descriptor right away instead of leaking it for the process lifetime.
    init_fd, init_path = tempfile.mkstemp()
    os.close(init_fd)

    worker = Worker(
        vllm_config=vllm_config,
        local_rank=0,
        rank=0,
        distributed_init_method=f"file://{init_path}",
    )

    worker.init_device()
    worker.load_model()

    # Nothing activated yet: the worker must report no adapters.
    set_active_loras(worker, [])
    assert worker.list_loras() == set()

    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files)
        for i in range(NUM_LORAS)
    ]

    # Activating every request must make exactly those ids visible.
    set_active_loras(worker, lora_requests)
    assert worker.list_loras() == {
        lora_request.lora_int_id
        for lora_request in lora_requests
    }

    for i in range(NUM_LORAS):
        # Seed per iteration so each random subset is reproducible.
        random.seed(i)
        iter_lora_requests = random.choices(lora_requests,
                                            k=random.randint(1, NUM_LORAS))
        random.shuffle(iter_lora_requests)

        # Drop a random number of trailing requests. A plain `[:-n]` slice
        # would return an empty list for n == 0 (since -0 == 0), so compute
        # the number of kept items explicitly and clamp it at zero.
        num_dropped = random.randint(0, NUM_LORAS)
        num_kept = max(len(iter_lora_requests) - num_dropped, 0)
        iter_lora_requests = iter_lora_requests[:num_kept]

        # Activate the subset built for this iteration (not the full list),
        # so the assertion below actually exercises re-activation.
        set_active_loras(worker, iter_lora_requests)
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})