# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py

import math
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal, Optional, Union

from vllm.logger import init_logger

# Module-level logger for this file; PEFTHelper.__post_init__ uses it to
# note (once) when rsLoRA-trained weights are being loaded.
logger = init_logger(__name__)


@dataclass
class PEFTHelper:
    """Validated subset of a PEFT LoRA adapter config, plus vLLM-only fields.

    On construction the config is checked for features vLLM does not
    support (``_validate_features``), and the derived ``vllm_*`` scaling
    fields are computed (``__post_init__``). ``from_dict`` builds an
    instance from a config mapping, ignoring keys that are not fields here.
    """

    # Required fields
    r: int
    lora_alpha: int
    target_modules: Union[list[str], str]

    bias: Literal["none", "all", "lora_only"] = field(default="none")
    modules_to_save: Optional[list[str]] = field(default=None)
    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
    use_rslora: bool = field(default=False)
    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
    use_dora: bool = field(default=False)
    # long context lora field
    context_length: int = field(default=0)
    # Extra vllm field, start with 'vllm_' to avoid conflict
    # Always overwritten in __post_init__ (lora_alpha / r, or
    # lora_alpha / sqrt(r) with rsLoRA), so a value passed here is ignored.
    vllm_lora_scaling_factor: float = field(default=1.0)
    # None means "not provided": when context_length is set, __post_init__
    # falls back to context_length. (Must not default to False: the fallback
    # would be skipped and the division below would hit zero.)
    vllm_max_position_embeddings: Optional[int] = field(default=None)
    # Set by __post_init__ only when context_length is non-zero.
    vllm_long_context_scaling_factor: Optional[float] = field(default=None)

    def _validate_features(self) -> None:
        """Reject configs that vLLM cannot serve.

        Every problem found is collected, and all of them are reported in a
        single exception.

        Raises:
            ValueError: If ``modules_to_save`` is set, DoRA is requested, or
                the rank ``r`` is not positive. The message lists all
                problems found, joined with ", ".
        """
        error_msg = []

        if self.modules_to_save:
            error_msg.append("vLLM only supports modules_to_save being None.")

        if self.use_dora:
            error_msg.append("vLLM does not yet support DoRA.")

        # r is the divisor of the scaling factor (directly, or via sqrt for
        # rsLoRA). Without this check a non-positive rank would surface as
        # an opaque ZeroDivisionError or math domain error.
        if self.r <= 0:
            error_msg.append(f"LoRA rank r must be positive, got {self.r}.")

        if error_msg:
            raise ValueError(", ".join(error_msg))

    def __post_init__(self) -> None:
        """Validate the config, then derive the ``vllm_*`` scaling fields.

        Raises:
            ValueError: Propagated from ``_validate_features``.
        """
        self._validate_features()
        if self.use_rslora:
            logger.info_once("Loading LoRA weights trained with rsLoRA.")
            # rsLoRA scales by alpha / sqrt(r) instead of alpha / r.
            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
        else:
            self.vllm_lora_scaling_factor = self.lora_alpha / self.r

        if self.context_length:
            # Any falsy value (None, or a legacy False / 0) means "not
            # provided". Fall back to context_length instead of dividing
            # by zero.
            if not self.vllm_max_position_embeddings:
                self.vllm_max_position_embeddings = self.context_length
            self.vllm_long_context_scaling_factor = float(
                math.ceil(self.context_length /
                          self.vllm_max_position_embeddings))

    @classmethod
    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
        """Build a PEFTHelper from a config mapping.

        Keys that are not fields of this class are silently dropped, so a
        larger config can be passed through unchanged.

        Args:
            config_dict: Mapping of config keys to values.

        Returns:
            A validated ``PEFTHelper`` instance.

        Raises:
            ValueError: If any field without a default is missing from
                ``config_dict``, or if validation in ``__post_init__`` fails.
        """
        # Get all field information from the class
        class_fields = {f.name: f for f in fields(cls)}
        # Fields with neither a default nor a default_factory must be given.
        required_fields = {
            name
            for name, f in class_fields.items()
            if f.default is MISSING and f.default_factory is MISSING
        }

        missing_fields = required_fields - set(config_dict)
        if missing_fields:
            raise ValueError(
                f"Missing required configuration fields: {missing_fields}")

        # Filter out fields that aren't defined in the class
        filtered_dict = {
            k: v
            for k, v in config_dict.items() if k in class_fields
        }
        return cls(**filtered_dict)