# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import math
import shutil

import pytest

from vllm.config.lora import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

# Illegal adapter configurations, one tuple per case:
#   (test_name, config_change, expected_error)
#   test_name      -- case id; test_peft_helper_error also uses it as the name
#                     of the temporary adapter directory it creates.
#   config_change  -- keys merged into a copy of adapter_config.json.
#   expected_error -- pattern the raised ValueError's message must match
#                     (passed to pytest.raises(match=...)).
ERROR_CASES = [
    (
        "test_rank",
        {"r": 1024},
        "is greater than max_lora_rank",
    ),
    (
        "test_bias",
        {"bias": "all"},
        "Adapter bias cannot be used without bias_enabled",
    ),
    ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
    (
        "test_modules_to_save",
        {"modules_to_save": ["lm_head"]},
        "only supports modules_to_save being None",
    ),
]


def test_peft_helper_pass(sql_lora_files, tmp_path):
    """Check that a legal adapter loads and validates, with and without RSLoRA.

    ``sql_lora_files`` is a fixture defined outside this file; here it is used
    as an adapter directory that contains ``adapter_config.json`` (it is
    copied with ``shutil.copytree`` below). ``tmp_path`` is pytest's
    per-test temporary directory.
    """
    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)

    # Baseline: the unmodified adapter must load and pass validation.
    peft_helper = PEFTHelper.from_local_dir(
        sql_lora_files, max_position_embeddings=4096
    )
    peft_helper.validate_legal(lora_config)
    assert peft_helper.r == 8
    assert peft_helper.lora_alpha == 16
    assert peft_helper.target_modules == [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    assert peft_helper.vllm_max_position_embeddings == 4096

    # RSLoRA: work on a copy so the shared fixture directory is not modified.
    test_dir = tmp_path / "test_rslora"
    shutil.copytree(sql_lora_files, test_dir)

    config_path = test_dir / "adapter_config.json"
    with open(config_path, encoding="utf-8") as f:
        adapter_config = json.load(f)
    adapter_config.update({"use_rslora": True})
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f)

    peft_helper = PEFTHelper.from_local_dir(test_dir, max_position_embeddings=4096)
    peft_helper.validate_legal(lora_config)
    # With RSLoRA enabled the expected scaling factor is lora_alpha / sqrt(r).
    expected_scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
    assert math.isclose(
        peft_helper.vllm_lora_scaling_factor, expected_scaling, abs_tol=1e-3
    )


@pytest.mark.parametrize(
    "test_name,config_change,expected_error",
    ERROR_CASES,
    # Use the case name as the pytest node id instead of auto-generated ids.
    ids=[case[0] for case in ERROR_CASES],
)
def test_peft_helper_error(
    sql_lora_files,
    tmp_path,
    test_name: str,
    config_change: dict,
    expected_error: str,
):
    """Check that each illegal adapter_config.json change raises ValueError.

    Args:
        sql_lora_files: fixture defined outside this file, used as the source
            adapter directory that is copied for each case.
        tmp_path: pytest's per-test temporary directory.
        test_name: case name; also used as the copied directory's name.
        config_change: keys merged into the copied adapter_config.json.
        expected_error: pattern passed to ``pytest.raises(match=...)``, which
            performs a regex search on the exception message.
    """
    # Work on a copy so the shared fixture directory is not modified.
    test_dir = tmp_path / test_name
    shutil.copytree(sql_lora_files, test_dir)

    config_path = test_dir / "adapter_config.json"
    with open(config_path, encoding="utf-8") as f:
        adapter_config = json.load(f)
    adapter_config.update(config_change)
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f)

    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
    # Loading and validation are both inside the context, so the expected
    # ValueError may come from either call.
    with pytest.raises(ValueError, match=expected_error):
        PEFTHelper.from_local_dir(
            test_dir, max_position_embeddings=4096
        ).validate_legal(lora_config)