# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence as GenericSequence
from typing import Optional

import torch
import torch.types

from vllm.lora.peft_helper import PEFTHelper
from vllm.utils.platform_utils import is_pin_memory_available


class LoRALayerWeights:
    """LoRA weights for a layer composed of two low rank matrixes."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
24
        scaling: float | None = None,
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
    ) -> None:
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b

        if scaling is None:
            self.scaling = self.lora_alpha / self.rank
        else:
            self.scaling = scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling == 1:
40
            return self
41
42
43
44
45
46
        self.lora_b *= self.scaling
        self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
47
        return self.lora_a.shape[1]
48
49
50

    @property
    def output_dim(self) -> int:
51
        return self.lora_b.shape[0]
52
53
54
55
56

    @property
    def is_packed(self) -> bool:
        return False

57
58
59
60
61
62
    @classmethod
    def from_config(
        cls,
        module_name: str,
        peft_helper: PEFTHelper,
    ) -> "LoRALayerWeights":
63
        # lora_a and lora_b are set to None for config-based construction
64
65
66
67
68
69
70
71
        return cls(
            module_name,
            peft_helper.r,
            peft_helper.lora_alpha,
            None,
            None,
            peft_helper.vllm_lora_scaling_factor,
        )
72

73
74
    @classmethod
    def create_dummy_lora_weights(
75
76
77
78
79
80
81
82
        cls,
        module_name: str,
        input_dim: int,
        output_dim: int,
        rank: int,
        dtype: torch.dtype,
        device: torch.types.Device,
    ) -> "LoRALayerWeights":
83
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
84
85
86
87
88
89
        lora_a = torch.zeros(
            [rank, input_dim], dtype=dtype, device=device, pin_memory=pin_memory
        )
        lora_b = torch.zeros(
            [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
        )
90

91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
        )


class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (eg. qkv_proj).

    ``lora_a``, ``lora_b``, ``lora_alphas`` and ``scaling`` are parallel
    lists with one slot per packed sub-module. A ``None`` slot means that
    sub-module has no LoRA.
    """

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: list[int | None],
        lora_a: list[torch.Tensor | None],
        lora_b: list[torch.Tensor | None],
        scaling: list[float] | None = None,
    ) -> None:
        """Store the per-slot weights and derive per-slot scalings.

        Args:
            module_name: Name of the packed module.
            rank: LoRA rank shared by all slots.
            lora_alphas: Per-slot alpha, or ``None`` for an empty slot.
            lora_a: Per-slot A matrix, or ``None`` for an empty slot.
            lora_b: Per-slot B matrix, or ``None`` for an empty slot.
            scaling: Per-slot scaling. When ``None`` each slot defaults to
                ``lora_alpha / rank`` (``None`` for slots without an alpha).
        """
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            scaling=scaling,  # type: ignore
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            # Empty slots carry a None alpha; dividing it would raise, so
            # propagate None as the scaling for those slots instead.
            self.scaling = [  # type: ignore
                lora_alpha / self.rank if lora_alpha is not None else None
                for lora_alpha in self.lora_alphas
            ]

    @classmethod
    def pack(
        cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
    ) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.

        Every non-None LoRA is optimized first (its scaling is merged into
        its ``lora_b``), so the packed object records a scaling of 1 for each
        populated slot. Rank and module name are taken from the first
        non-None entry.

        Raises:
            ValueError: If ``loras`` contains no non-None entry.
        """
        first_lora = next((lora for lora in loras if lora is not None), None)
        if first_lora is None:
            # Raise a descriptive error instead of leaking StopIteration.
            raise ValueError("Cannot pack LoRAs: every entry in `loras` is None.")
        for lora in loras:
            if lora is not None:
                lora.optimize()
        return cls(
            first_lora.module_name,
            first_lora.rank,
            [lora.lora_alpha if lora is not None else None for lora in loras],
            [lora.lora_a if lora is not None else None for lora in loras],
            [lora.lora_b if lora is not None else None for lora in loras],
            scaling=[
                1 if lora is not None else None  # type: ignore
                for lora in loras
            ],
        )

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b.

        Each populated slot has its ``lora_b`` multiplied in place by the
        slot's scaling, which is then reset to 1. Slots with no ``lora_b``,
        no scaling, or a scaling already equal to 1 are left untouched.

        Returns:
            ``self``, to allow chaining.
        """
        for i, lora_b in enumerate(self.lora_b):
            scale = self.scaling[i]  # type: ignore
            if lora_b is None or scale is None or scale == 1:
                continue
            self.lora_b[i] *= scale  # type: ignore
            self.scaling[i] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        """Not defined for packed LoRAs."""
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        """Not defined for packed LoRAs."""
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        """Whether this object packs several LoRAs; always ``True`` here."""
        return True