peft_helper.py 4.26 KB
Newer Older
1
2
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py

3
import json
4
import math
5
import os
6
from dataclasses import MISSING, dataclass, field, fields
7
from typing import List, Literal, Optional, Union
8

9
from vllm.config import LoRAConfig
10
11
12
from vllm.logger import init_logger

# Module-level logger; PEFTHelper uses it to note when rsLoRA weights load.
logger = init_logger(__name__)
13

14
15
16

@dataclass
class PEFTHelper:
    """
    A helper class for PEFT configurations, specifically designed for LoRA.
    This class handles configuration validation, compatibility checks for
    various LoRA implementations.

    Instances are normally built with ``from_dict`` or ``from_local_dir``,
    which ignore adapter-config keys that are not fields of this class.
    The ``vllm_``-prefixed fields are derived in ``__post_init__``.
    """

    # Required fields
    r: int
    lora_alpha: int
    target_modules: Union[list[str], str]

    bias: Literal["none", "all", "lora_only"] = field(default="none")
    modules_to_save: Optional[list[str]] = field(default=None)
    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
    use_rslora: bool = field(default=False)
    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
    use_dora: bool = field(default=False)
    # long context lora field
    context_length: int = field(default=0)
    # Extra vllm field, start with 'vllm_' to avoid conflict
    # Always recomputed from lora_alpha and r in __post_init__.
    vllm_lora_scaling_factor: float = field(default=1.0)
    # None means "not provided"; __post_init__ then falls back to
    # context_length when a long-context adapter is being loaded.
    vllm_max_position_embeddings: Optional[int] = field(default=None)
    # Only set when context_length is non-zero.
    vllm_long_context_scaling_factor: Optional[float] = field(default=None)

    def _validate_features(self) -> List[str]:
        """
        Check if there are any unsupported Lora features.

        Returns:
            A list of human-readable error messages, one per unsupported
            feature that is enabled; empty when everything is supported.
        """
        error_msg = []
        if self.modules_to_save:
            error_msg.append("vLLM only supports modules_to_save being None.")
        if self.use_dora:
            error_msg.append("vLLM does not yet support DoRA.")
        return error_msg

    def __post_init__(self):
        """
        Derive the vLLM-specific fields from the loaded PEFT values.

        The LoRA scaling factor is ``lora_alpha / sqrt(r)`` for rsLoRA and
        ``lora_alpha / r`` otherwise. When ``context_length`` is non-zero,
        the long-context scaling factor is the ceiling of
        ``context_length / vllm_max_position_embeddings`` stored as a float.
        """
        if self.use_rslora:
            logger.info_once("Loading LoRA weights trained with rsLoRA.")
            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
        else:
            self.vllm_lora_scaling_factor = self.lora_alpha / self.r

        if self.context_length:
            # Without an explicit limit, use the adapter's own context
            # length, which yields a scaling factor of 1.0.
            if self.vllm_max_position_embeddings is None:
                self.vllm_max_position_embeddings = self.context_length
            self.vllm_long_context_scaling_factor = float(
                math.ceil(self.context_length /
                          self.vllm_max_position_embeddings))

    @classmethod
    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
        """
        Build a helper from an adapter configuration dictionary.

        Args:
            config_dict: Mapping of configuration names to values. Keys
                that are not fields of this class are ignored.

        Returns:
            A new ``PEFTHelper`` populated from the recognised keys.

        Raises:
            ValueError: If any field without a default is absent.
        """
        # Get all field information from the class
        class_fields = {f.name: f for f in fields(cls)}
        # A field is required when it has neither a default nor a factory
        required_fields = {
            name
            for name, f in class_fields.items()
            if f.default is MISSING and f.default_factory is MISSING
        }

        # Identify any missing required fields
        missing_fields = required_fields - set(config_dict.keys())
        if missing_fields:
            raise ValueError(
                f"Missing required configuration fields: {missing_fields}")

        # Filter out fields that aren't defined in the class
        filtered_dict = {
            k: v
            for k, v in config_dict.items() if k in class_fields
        }
        return cls(**filtered_dict)

    @classmethod
    def from_local_dir(cls, lora_path: str,
                       max_position_embeddings: Optional[int]) -> "PEFTHelper":
        """
        Build a helper from the ``adapter_config.json`` inside a directory.

        Args:
            lora_path: Directory containing ``adapter_config.json``.
            max_position_embeddings: Value stored in the configuration as
                ``vllm_max_position_embeddings`` before construction;
                may be None.

        Returns:
            A new ``PEFTHelper`` built via ``from_dict``.

        Raises:
            ValueError: If a required field is missing from the file.
        """
        lora_config_path = os.path.join(lora_path, "adapter_config.json")

        # Read as UTF-8 explicitly rather than relying on the platform's
        # default text encoding.
        with open(lora_config_path, encoding="utf-8") as f:
            config = json.load(f)
        config["vllm_max_position_embeddings"] = max_position_embeddings
        return cls.from_dict(config)

    def validate_legal(self, lora_config: "LoRAConfig") -> None:
        """
        Validates the LoRA configuration settings against application
        constraints and requirements.

        Args:
            lora_config: Configuration whose ``max_lora_rank`` and
                ``bias_enabled`` attributes are checked against this
                adapter.

        Raises:
            ValueError: If any check fails; the message is every
                individual problem joined by single spaces.
        """
        # Collect every problem first so the caller sees them all at once.
        error_msg = self._validate_features()
        if self.r > lora_config.max_lora_rank:
            error_msg.append(
                f"LoRA rank {self.r} is greater than max_lora_rank"
                f" {lora_config.max_lora_rank}.")
        if self.bias != "none" and not lora_config.bias_enabled:
            error_msg.append(
                "Adapter bias cannot be used without bias_enabled.")
        if error_msg:
            raise ValueError(" ".join(error_msg))