# peft_helper.py (4.76 KB)
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py

6
import json
7
import math
8
import os
9
from dataclasses import MISSING, dataclass, field, fields
10
from typing import Literal, Optional, Union
11

12
from vllm.config.lora import LoRAConfig
13
from vllm.logger import init_logger
14
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
15
16

# Module-level logger shared by the helpers in this file.
logger = init_logger(__name__)

@dataclass
class PEFTHelper:
    """
    A helper class for PEFT configurations, specifically designed for LoRA.

    This class handles configuration validation and compatibility checks for
    various LoRA implementations. Instances are normally built from a PEFT
    ``adapter_config.json`` payload via :meth:`from_local_dir` or
    :meth:`from_dict`.
    """

    # Required fields
    r: int
    lora_alpha: int
    target_modules: Union[list[str], str]

    bias: Literal["none", "all", "lora_only"] = field(default="none")
    modules_to_save: Optional[list[str]] = field(default=None)
    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
    use_rslora: bool = field(default=False)
    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
    use_dora: bool = field(default=False)
    # Extra vllm field, start with 'vllm_' to avoid conflict.
    # Always recomputed in __post_init__; any value passed in is overwritten.
    vllm_lora_scaling_factor: float = field(default=1.0)
    # None when the caller did not supply a value.
    vllm_max_position_embeddings: Optional[int] = field(default=None)

    def _validate_features(self) -> list[str]:
        """
        Check if there are any unsupported LoRA features.

        Returns:
            A list of human-readable error messages, one per unsupported
            feature found (non-empty ``modules_to_save``, DoRA). An empty
            list means no unsupported features were detected.
        """
        error_msg = []
        if self.modules_to_save:
            error_msg.append("vLLM only supports modules_to_save being None.")
        if self.use_dora:
            error_msg.append("vLLM does not yet support DoRA.")
        return error_msg

    def __post_init__(self):
        """
        Validate the rank and derive ``vllm_lora_scaling_factor``.

        The factor is ``lora_alpha / sqrt(r)`` for rsLoRA and
        ``lora_alpha / r`` otherwise.

        Raises:
            ValueError: if ``r`` is not positive.
        """
        # Guard the divisions below: a non-positive rank would otherwise
        # surface as ZeroDivisionError or a math domain error.
        if self.r <= 0:
            raise ValueError(
                f"LoRA rank r must be a positive integer, got {self.r}.")
        if self.use_rslora:
            logger.info_once("Loading LoRA weights trained with rsLoRA.")
            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
        else:
            self.vllm_lora_scaling_factor = self.lora_alpha / self.r

    @classmethod
    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
        """
        Build a PEFTHelper from a PEFT-style configuration dictionary.

        Keys that are not fields of this dataclass are ignored, so a raw
        ``adapter_config.json`` payload can be passed directly.

        Args:
            config_dict: Mapping of configuration keys to values.

        Returns:
            A new PEFTHelper instance.

        Raises:
            ValueError: if any field without a default is absent from
                ``config_dict``.
        """
        class_fields = {f.name: f for f in fields(cls)}
        # Fields with neither a default nor a default_factory are required.
        required_fields = {
            name
            for name, f in class_fields.items()
            if f.default is MISSING and f.default_factory is MISSING
        }

        missing_fields = required_fields.difference(config_dict)
        if missing_fields:
            # Sorted so the message is deterministic across runs.
            raise ValueError("Missing required configuration fields: "
                             f"{sorted(missing_fields)}")

        # Drop keys that aren't defined on the class (e.g. PEFT-only keys).
        filtered_dict = {
            k: v
            for k, v in config_dict.items() if k in class_fields
        }
        return cls(**filtered_dict)

    @classmethod
    def from_local_dir(
            cls,
            lora_path: str,
            max_position_embeddings: Optional[int],
            tensorizer_config_dict: Optional[dict] = None) -> "PEFTHelper":
        """
        Load ``adapter_config.json`` and build a PEFTHelper from it.

        Args:
            lora_path: Directory containing ``adapter_config.json``. Ignored
                when ``tensorizer_config_dict`` is truthy.
            max_position_embeddings: Stored as
                ``vllm_max_position_embeddings`` on the result.
            tensorizer_config_dict: Optional keyword arguments for
                ``TensorizerConfig``. When given, the config file is read
                from ``tensorizer_dir`` through tensorizer's ``open_stream``.

        Returns:
            A new PEFTHelper instance.

        Raises:
            ValueError: propagated from :meth:`from_dict` or
                ``__post_init__`` when the config is invalid.
        """
        if tensorizer_config_dict:
            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            # Imported lazily so tensorizer is only required on this path.
            from tensorizer.stream_io import open_stream
            lora_config_path = os.path.join(tensorizer_config.tensorizer_dir,
                                            "adapter_config.json")
            with open_stream(lora_config_path,
                             mode="rb",
                             **tensorizer_args.stream_kwargs) as f:
                config = json.load(f)

            logger.info("Successfully deserialized LoRA config from %s",
                        tensorizer_config.tensorizer_dir)
        else:
            lora_config_path = os.path.join(lora_path, "adapter_config.json")
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            with open(lora_config_path, encoding="utf-8") as f:
                config = json.load(f)

        config["vllm_max_position_embeddings"] = max_position_embeddings
        return cls.from_dict(config)

    def validate_legal(self, lora_config: "LoRAConfig") -> None:
        """
        Validate the LoRA configuration against vLLM's constraints.

        Collects every problem first, then raises once so callers see all
        issues in a single error.

        Args:
            lora_config: Engine LoRA config; ``max_lora_rank`` and
                ``bias_enabled`` are read.

        Raises:
            ValueError: with the space-joined messages when any check fails.
        """
        error_msg = self._validate_features()
        if self.r > lora_config.max_lora_rank:
            error_msg.append(
                f"LoRA rank {self.r} is greater than max_lora_rank"
                f" {lora_config.max_lora_rank}.")
        if self.bias != "none" and not lora_config.bias_enabled:
            error_msg.append(
                "Adapter bias cannot be used without bias_enabled.")
        if error_msg:
            raise ValueError(" ".join(error_msg))