peft_helper.py 2.91 KB
Newer Older
1
2
3
4
5
6
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py

import math
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal, Optional, Union

7
8
from vllm.utils import print_info_once

9
10
11
12
13
14
15
16
17
18

@dataclass
class PEFTHelper:
    """The subset of a PEFT LoRA config that vLLM needs to load an adapter.

    Fields mirror keys of a PEFT LoRA config. Fields prefixed with ``vllm_``
    are vLLM-specific and are derived in ``__post_init__``. Construction
    raises ``ValueError`` if the config uses a feature vLLM cannot serve.
    """

    # Required fields
    r: int
    lora_alpha: int
    target_modules: Union[list[str], str]

    bias: Literal["none", "all", "lora_only"] = field(default="none")
    modules_to_save: Optional[list[str]] = field(default=None)
    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
    use_rslora: bool = field(default=False)
    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
    use_dora: bool = field(default=False)
    # long context lora field
    context_length: int = field(default=0)
    # Extra vllm field, start with 'vllm_' to avoid conflict
    # Always recomputed in __post_init__ from lora_alpha and r.
    vllm_lora_scaling_factor: float = field(default=1.0)
    # None means "unknown": when context_length is set, __post_init__ falls
    # back to context_length so the scaling-factor division is never by zero.
    vllm_max_position_embeddings: Optional[int] = field(default=None)
    # Only computed when context_length is non-zero; otherwise stays None.
    vllm_long_context_scaling_factor: Optional[float] = field(default=None)

    def _validate_features(self) -> None:
        """Raise ValueError describing every unsupported feature at once."""
        error_msg = []

        if self.modules_to_save:
            error_msg.append("vLLM only supports modules_to_save being None.")

        if self.use_dora:
            error_msg.append("vLLM does not yet support DoRA.")

        if error_msg:
            # Each message is a complete sentence ending in '.', so separate
            # them with a space (", " produced "None., vLLM ...").
            raise ValueError(" ".join(error_msg))

    def __post_init__(self) -> None:
        """Validate the config and derive the ``vllm_*`` fields.

        Raises:
            ValueError: if the config uses a feature vLLM does not support.
        """
        self._validate_features()

        # rsLoRA scales the update by alpha / sqrt(r); classic LoRA by
        # alpha / r.
        if self.use_rslora:
            print_info_once("Loading LoRA weights trained with rsLoRA.")
            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
        else:
            self.vllm_lora_scaling_factor = self.lora_alpha / self.r

        if self.context_length:
            # Any falsy value (None, 0, or a legacy False) means the base
            # model's max positions are unknown; dividing by it would raise
            # ZeroDivisionError, so fall back to the adapter's context length.
            if not self.vllm_max_position_embeddings:
                self.vllm_max_position_embeddings = self.context_length
            self.vllm_long_context_scaling_factor = float(
                math.ceil(self.context_length /
                          self.vllm_max_position_embeddings))

    @classmethod
    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
        """Build a helper from a PEFT config dict.

        Keys in ``config_dict`` that are not fields of this class are
        ignored, so a full PEFT config can be passed as-is.

        Raises:
            ValueError: if a field without a default is missing from
                ``config_dict``, or if the resulting config is unsupported.
        """
        class_fields = {f.name: f for f in fields(cls)}
        # A field is required when it has neither a default nor a factory.
        required_fields = {
            name
            for name, f in class_fields.items()
            if f.default is MISSING and f.default_factory is MISSING
        }

        missing_fields = required_fields - set(config_dict)
        if missing_fields:
            raise ValueError(
                f"Missing required configuration fields: {missing_fields}")

        # Drop keys the dataclass does not declare; cls() would reject them.
        filtered_dict = {
            k: v
            for k, v in config_dict.items() if k in class_fields
        }
        return cls(**filtered_dict)