# peft_helper.py (5.14 KB)
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py

5
import json
6
import math
7
import os
8
from dataclasses import MISSING, dataclass, field, fields
9
from typing import Literal, Optional, Union
10

11
from vllm.config import LoRAConfig
12
from vllm.logger import init_logger
13
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
14
15

# Module-level logger, created through vLLM's init_logger and named after
# this module.
logger = init_logger(__name__)
16

17
18
19

@dataclass
class PEFTHelper:
    """Helper for PEFT adapter configurations, specifically LoRA.

    Holds the subset of a PEFT ``adapter_config.json`` that vLLM reads,
    derives the vLLM-specific scaling factors from it, and checks the
    adapter against the features and limits vLLM supports.
    """

    # Required fields
    # LoRA rank.
    r: int
    # Numerator of the LoRA scaling factor (see ``__post_init__``).
    lora_alpha: int
    target_modules: Union[list[str], str]

    bias: Literal["none", "all", "lora_only"] = field(default="none")
    modules_to_save: Optional[list[str]] = field(default=None)
    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
    use_rslora: bool = field(default=False)
    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
    use_dora: bool = field(default=False)
    # long context lora field
    context_length: int = field(default=0)
    # Extra vllm field, start with 'vllm_' to avoid conflict
    # Overwritten in ``__post_init__`` from ``lora_alpha`` and ``r``.
    vllm_lora_scaling_factor: float = field(default=1.0)
    # ``None`` means "not provided"; ``__post_init__`` then falls back to
    # ``context_length`` when a long-context adapter is loaded.
    vllm_max_position_embeddings: Optional[int] = field(default=None)
    vllm_long_context_scaling_factor: Optional[float] = field(default=None)

    def _validate_features(self) -> list[str]:
        """Check if there are any unsupported LoRA features.

        Returns:
            One human-readable message per unsupported feature that this
            configuration enables; empty when nothing is unsupported.
        """
        error_msg = []
        if self.modules_to_save:
            error_msg.append("vLLM only supports modules_to_save being None.")
        if self.use_dora:
            error_msg.append("vLLM does not yet support DoRA.")
        return error_msg

    def __post_init__(self):
        """Derive the vLLM-specific scaling fields from the PEFT fields."""
        # rsLoRA scales by alpha / sqrt(r); plain LoRA by alpha / r.
        if self.use_rslora:
            logger.info_once("Loading LoRA weights trained with rsLoRA.")
            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
        else:
            self.vllm_lora_scaling_factor = self.lora_alpha / self.r

        if self.context_length:
            # Treat every falsy value (None, 0, or a legacy False) as "not
            # provided"; dividing by it below would raise ZeroDivisionError.
            if not self.vllm_max_position_embeddings:
                self.vllm_max_position_embeddings = self.context_length
            # Round the context-length ratio up to a whole number.
            self.vllm_long_context_scaling_factor = float(
                math.ceil(self.context_length /
                          self.vllm_max_position_embeddings))

    @classmethod
    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
        """Build a helper from a configuration dictionary.

        Keys that do not correspond to a field of this class are ignored.

        Args:
            config_dict: Mapping of configuration names to values.

        Returns:
            A new ``PEFTHelper`` populated from the recognised keys.

        Raises:
            ValueError: If a field without a default is missing from
                ``config_dict``.
        """
        # Get all field information from the class
        class_fields = {f.name: f for f in fields(cls)}
        # A field is required when it has neither a default nor a factory.
        required_fields = {
            name
            for name, f in class_fields.items()
            if f.default is MISSING and f.default_factory is MISSING
        }

        # Identify any missing required fields
        missing_fields = required_fields - set(config_dict.keys())
        if missing_fields:
            raise ValueError(
                f"Missing required configuration fields: {missing_fields}")

        # Filter out fields that aren't defined in the class
        filtered_dict = {
            k: v
            for k, v in config_dict.items() if k in class_fields
        }
        return cls(**filtered_dict)

    @classmethod
    def from_local_dir(
            cls,
            lora_path: str,
            max_position_embeddings: Optional[int],
            tensorizer_config_dict: Optional[dict] = None) -> "PEFTHelper":
        """Load ``adapter_config.json`` and build a helper from it.

        Args:
            lora_path: Directory containing ``adapter_config.json``. Only
                used when ``tensorizer_config_dict`` is empty or ``None``.
            max_position_embeddings: Stored in the loaded configuration as
                ``vllm_max_position_embeddings`` before the helper is built.
            tensorizer_config_dict: Keyword arguments for
                ``TensorizerConfig``. When given, the configuration is read
                through tensorizer's ``open_stream`` from the config's
                ``lora_dir`` instead of from ``lora_path``.

        Returns:
            A new ``PEFTHelper`` built from the loaded configuration.

        Raises:
            ValueError: If the loaded configuration lacks a required field.
        """
        if tensorizer_config_dict:
            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            # Imported lazily so tensorizer is only needed on this path.
            from tensorizer.stream_io import open_stream
            lora_config_path = os.path.join(tensorizer_config.lora_dir,
                                            "adapter_config.json")
            with open_stream(lora_config_path,
                             mode="rb",
                             **tensorizer_args.stream_params) as f:
                config = json.load(f)

            logger.info("Successfully deserialized LoRA config from %s",
                        tensorizer_config.lora_dir)

        else:
            lora_config_path = os.path.join(lora_path, "adapter_config.json")
            # Decode as UTF-8 explicitly instead of the platform default.
            with open(lora_config_path, encoding="utf-8") as f:
                config = json.load(f)

        config["vllm_max_position_embeddings"] = max_position_embeddings
        return cls.from_dict(config)

    def validate_legal(self, lora_config: "LoRAConfig") -> None:
        """Validate this adapter against the given LoRA configuration.

        Args:
            lora_config: Configuration providing ``max_lora_rank`` and
                ``bias_enabled``.

        Raises:
            ValueError: If the adapter uses an unsupported feature, its rank
                exceeds ``max_lora_rank``, or it uses bias while
                ``bias_enabled`` is false. All problems found are reported
                together in a single space-separated message.
        """
        error_msg = self._validate_features()
        if self.r > lora_config.max_lora_rank:
            error_msg.append(
                f"LoRA rank {self.r} is greater than max_lora_rank"
                f" {lora_config.max_lora_rank}.")
        if self.bias != "none" and not lora_config.bias_enabled:
            error_msg.append(
                "Adapter bias cannot be used without bias_enabled.")
        if error_msg:
            raise ValueError(" ".join(error_msg))