lora_weights.py 7.69 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Sequence as GenericSequence
5
6

import torch
7
import torch.types
8

9
from vllm.lora.peft_helper import PEFTHelper
10
from vllm.utils.platform_utils import is_pin_memory_available
11
12
13
14
15
16
17
18
19
20
21
22


class LoRALayerWeights:
    """LoRA weights for a layer composed of two low rank matrixes."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
23
        scaling: float | None = None,
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    ) -> None:
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b

        if scaling is None:
            self.scaling = self.lora_alpha / self.rank
        else:
            self.scaling = scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling == 1:
39
            return self
40
41
42
43
44
45
        self.lora_b *= self.scaling
        self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
46
        return self.lora_a.shape[1]
47
48
49

    @property
    def output_dim(self) -> int:
50
        return self.lora_b.shape[0]
51
52
53
54
55

    @property
    def is_packed(self) -> bool:
        return False

56
57
58
59
60
61
    @classmethod
    def from_config(
        cls,
        module_name: str,
        peft_helper: PEFTHelper,
    ) -> "LoRALayerWeights":
62
        # lora_a and lora_b are set to None for config-based construction
63
64
65
66
67
68
69
70
        return cls(
            module_name,
            peft_helper.r,
            peft_helper.lora_alpha,
            None,
            None,
            peft_helper.vllm_lora_scaling_factor,
        )
71

72
73
    @classmethod
    def create_dummy_lora_weights(
74
75
76
77
78
79
80
81
        cls,
        module_name: str,
        input_dim: int,
        output_dim: int,
        rank: int,
        dtype: torch.dtype,
        device: torch.types.Device,
    ) -> "LoRALayerWeights":
82
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
83
84
85
86
87
88
        lora_a = torch.zeros(
            [rank, input_dim], dtype=dtype, device=device, pin_memory=pin_memory
        )
        lora_b = torch.zeros(
            [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
        )
89

90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
        )


class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (eg. qkv_proj) and stacked MoE experts.

    ``lora_alphas``, ``lora_a``, ``lora_b`` and ``scaling`` are lists with
    one entry per packed sub-module. A ``None`` entry means that sub-module
    has no LoRA.
    """

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: list[int | None],
        lora_a: list[torch.Tensor | None],
        lora_b: list[torch.Tensor | None],
        scaling: list[float] | None = None,
    ) -> None:
        """Store per-sub-module weights.

        When ``scaling`` is ``None``, each entry is derived as
        ``lora_alpha / rank``. Sub-modules whose alpha is ``None`` (no LoRA)
        get a ``None`` scaling.
        """
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            scaling=scaling,  # type: ignore
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            # Absent sub-modules have alpha None; dividing it would raise.
            self.scaling = [  # type: ignore
                None if lora_alpha is None else lora_alpha / self.rank
                for lora_alpha in self.lora_alphas
            ]

    @classmethod
    def pack(
        cls, loras: GenericSequence["LoRALayerWeights | None"]
    ) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        Each present LoRA is optimized in place first, merging its scaling
        into its ``lora_b``. The packed scaling is therefore 1 for present
        sub-modules and ``None`` for absent ones.

        Raises:
            ValueError: If every entry of ``loras`` is None.
        """
        present = [lora for lora in loras if lora is not None]
        if not present:
            # An explicit error instead of a bare StopIteration from next(),
            # which can silently end an enclosing iteration (e.g. map()).
            raise ValueError("Cannot pack LoRAs: every entry is None.")
        for lora in present:
            lora.optimize()
        first_lora = present[0]
        return cls(
            first_lora.module_name,
            first_lora.rank,
            [lora.lora_alpha if lora is not None else None for lora in loras],
            [lora.lora_a if lora is not None else None for lora in loras],
            [lora.lora_b if lora is not None else None for lora in loras],
            scaling=[
                1 if lora is not None else None  # type: ignore
                for lora in loras
            ],
        )

    @classmethod
    def pack_moe(
        cls,
        loras: GenericSequence["LoRALayerWeights | None"],
        module_name: str,
        is_non_gated_moe: bool = False,
    ) -> "PackedLoRALayerWeights":
        """Pack per-expert LoRAs into one LoRA with stacked w1/w2/w3 weights.

        ``loras`` holds consecutive ``(w1, w2, w3)`` entries, one triple per
        expert, so its length must be a multiple of 3. Each stacked tensor
        has the expert index as its leading dimension.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        With ``is_non_gated_moe``, a missing w3 LoRA is replaced by w1's, and
        the packed w3 entry reuses the stacked w1 tensors.

        Raises:
            ValueError: If every entry of ``loras`` is None.
        """
        first_lora = next((lora for lora in loras if lora is not None), None)
        if first_lora is None:
            raise ValueError("Cannot pack MoE LoRAs: every entry is None.")
        rank = first_lora.rank
        lora_alpha = first_lora.lora_alpha
        assert len(loras) % 3 == 0
        w1_lora_a_lst = []
        w2_lora_a_lst = []
        w3_lora_a_lst = []
        w1_lora_b_lst = []
        w2_lora_b_lst = []
        w3_lora_b_lst = []
        # TODO: Consider the case where some experts don't have LoRA added.
        for eid in range(len(loras) // 3):
            w1_lora = loras[eid * 3]
            w2_lora = loras[eid * 3 + 1]
            w3_lora = loras[eid * 3 + 2]
            # For non-gated MoE, w3 is not used, so we use w1's LoRA weights
            # This is determined by checking the expert mapping (get_expert_mapping)
            # which indicates when ckpt_up_proj_name is empty.
            if w3_lora is None and is_non_gated_moe:
                w3_lora = w1_lora
            assert w1_lora is not None
            assert w2_lora is not None
            assert w3_lora is not None

            w1_lora_a_lst.append(w1_lora.lora_a)
            w2_lora_a_lst.append(w2_lora.lora_a)
            w3_lora_a_lst.append(w3_lora.lora_a)

            w1_lora_b_lst.append(w1_lora.lora_b)
            w2_lora_b_lst.append(w2_lora.lora_b)
            w3_lora_b_lst.append(w3_lora.lora_b)

        w1_lora_a = torch.stack(w1_lora_a_lst, dim=0)  # (num_experts,rank,input_size)
        w2_lora_a = torch.stack(w2_lora_a_lst, dim=0)
        w1_lora_b = torch.stack(w1_lora_b_lst, dim=0)  # (num_experts,output_size,rank)
        w2_lora_b = torch.stack(w2_lora_b_lst, dim=0)

        # All w1, w2, w3 have the same scaling factor. Take it from the LoRA
        # itself rather than recomputing lora_alpha / rank. The stored value
        # honours an explicit factor (from_config uses the PEFT helper's
        # vllm_lora_scaling_factor). It is also 1 when the expert LoRAs were
        # already optimized, so it is never applied to lora_b twice.
        scaling = first_lora.scaling
        last_scaling = scaling

        if is_non_gated_moe:
            # For non-gated MoE, reuse w1 tensors for w3 to avoid memory waste
            # w3_lora_a_lst and w3_lora_b_lst are not relevant in this case
            w3_lora_a = w1_lora_a
            w3_lora_b = w1_lora_b

            # For non-gated MoE, avoid double-scaling by setting w3's scaling to 1.
            last_scaling = 1.0
        else:
            w3_lora_a = torch.stack(w3_lora_a_lst, dim=0)
            w3_lora_b = torch.stack(w3_lora_b_lst, dim=0)

        return cls(
            module_name,
            rank,
            [lora_alpha, lora_alpha, lora_alpha],
            [w1_lora_a, w2_lora_a, w3_lora_a],
            [w1_lora_b, w2_lora_b, w3_lora_b],
            scaling=[scaling, scaling, last_scaling],
        )

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging each scaling into its lora_b.

        Each ``lora_b[i]`` is multiplied in place and its scaling set to 1.
        Entries whose scaling is already 1, or whose ``lora_b`` is None, are
        skipped.
        """
        for i in range(len(self.lora_b)):
            if self.scaling[i] == 1 or self.lora_b[i] is None:  # type: ignore
                continue
            self.lora_b[i] *= self.scaling[i]  # type: ignore
            self.scaling[i] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        """Not defined for packed weights (sub-modules may differ)."""
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        """Not defined for packed weights (sub-modules may differ)."""
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        """Always ``True`` for packed weights."""
        return True