lora_weights.py 6.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Sequence as GenericSequence
from typing import Optional
6
7

import torch
8
import torch.types
9

10
from vllm.lora.peft_helper import PEFTHelper
11
from vllm.utils import is_pin_memory_available
12
13
14
15
16
17
18
19
20
21
22
23


class LoRALayerWeights:
    """LoRA weights for a layer composed of two low-rank matrices.

    ``lora_a`` is laid out as ``(rank, input_dim)`` and ``lora_b`` as
    ``(output_dim, rank)``: this is the layout ``create_dummy_lora_weights``
    allocates and the one ``input_dim`` / ``output_dim`` read back.
    """

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
        """Store the adapter tensors and resolve the scaling factor.

        Args:
            module_name: Name of the module this LoRA applies to.
            rank: LoRA rank.
            lora_alpha: LoRA alpha; used to derive the scaling factor when
                ``scaling`` is not given.
            lora_a: The A matrix. ``from_config`` passes ``None`` here.
            lora_b: The B matrix. ``from_config`` passes ``None`` here.
            bias: Optional bias tensor.
            embeddings_tensor: Optional extra-vocabulary embedding rows.
            scaling: Explicit scaling factor. Defaults to
                ``lora_alpha / rank``.
        """
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.bias = bias
        self.embeddings_tensor = embeddings_tensor

        if scaling is None:
            self.scaling = self.lora_alpha / self.rank
        else:
            self.scaling = scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b.

        ``lora_b`` is multiplied in place, so any other reference to the
        same tensor observes the change. Afterwards ``scaling`` is 1.
        This is a no-op when the scaling is already 1, or when there is no
        ``lora_b`` to fold into (e.g. instances built by ``from_config``).

        Returns:
            ``self``, to allow chaining.
        """
        # A weight-less instance has nothing to fold into. Keep its scaling
        # rather than failing on ``None *= float``.
        if self.scaling == 1 or self.lora_b is None:
            return self
        self.lora_b *= self.scaling
        self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
        """Input feature size, read from the second dim of ``lora_a``."""
        return self.lora_a.shape[1]

    @property
    def output_dim(self) -> int:
        """Output feature size, read from the first dim of ``lora_b``."""
        return self.lora_b.shape[0]

    @property
    def is_packed(self) -> bool:
        """Whether this object holds several packed sub-LoRAs (never here)."""
        return False

    @property
    def extra_vocab_size(self) -> int:
        """Number of rows in ``embeddings_tensor``, or 0 when there is none."""
        return (
            self.embeddings_tensor.shape[0] if self.embeddings_tensor is not None else 0
        )

    @classmethod
    def from_config(
        cls,
        module_name: str,
        peft_helper: "PEFTHelper",
        embeddings_tensor: Optional[torch.Tensor] = None,
    ) -> "LoRALayerWeights":
        """Build a weight-less instance from a PEFT config.

        ``lora_a``, ``lora_b`` and ``bias`` are left as ``None``. Rank and
        alpha come from ``peft_helper``, as does the optional explicit
        scaling (``vllm_lora_scaling_factor``).
        """
        return cls(
            module_name,
            peft_helper.r,  # rank
            peft_helper.lora_alpha,  # lora_alpha
            None,  # lora_a
            None,  # lora_b
            None,  # bias
            embeddings_tensor,
            peft_helper.vllm_lora_scaling_factor,  # scaling
        )

    @classmethod
    def create_dummy_lora_weights(
        cls,
        module_name: str,
        input_dim: int,
        output_dim: int,
        rank: int,
        dtype: torch.dtype,
        device: torch.types.Device,
        embeddings_tensor_dim: Optional[int] = None,
        bias_enabled: Optional[bool] = False,
    ) -> "LoRALayerWeights":
        """Allocate a placeholder LoRA with zero-filled A/B matrices.

        ``lora_alpha`` is fixed to 1, so the resulting scaling is
        ``1 / rank``. Host tensors are allocated in pinned memory when
        ``device`` is ``"cpu"`` and pinning is available.

        Args:
            module_name: Name of the module this LoRA applies to.
            input_dim: Column count of ``lora_a``.
            output_dim: Row count of ``lora_b`` (and length of the bias).
            rank: LoRA rank.
            dtype: dtype of every allocated tensor.
            device: Device of every allocated tensor.
            embeddings_tensor_dim: When truthy, a random ``(10, dim)``
                embeddings tensor is allocated. Otherwise there is none.
            bias_enabled: When true, a zero bias of length ``output_dim``
                is allocated.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        lora_a = torch.zeros(
            [rank, input_dim], dtype=dtype, device=device, pin_memory=pin_memory
        )
        lora_b = torch.zeros(
            [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
        )
        if bias_enabled:
            bias = torch.zeros(
                [output_dim], dtype=dtype, device=device, pin_memory=pin_memory
            )
        else:
            bias = None

        embeddings_tensor = (
            torch.rand(
                10,
                embeddings_tensor_dim,
                dtype=dtype,
                device=device,
                pin_memory=pin_memory,
            )
            if embeddings_tensor_dim
            else None
        )
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            bias=bias,
            embeddings_tensor=embeddings_tensor,
        )


class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (e.g. qkv_proj).

    Holds one slot per packed submodule. The per-slot attributes
    ``lora_alphas``, ``lora_a``, ``lora_b``, ``bias`` and ``scaling`` are
    lists, and a ``None`` entry means that submodule has no LoRA.
    """

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: list[Optional[int]],
        lora_a: list[Optional[torch.Tensor]],
        lora_b: list[Optional[torch.Tensor]],
        bias: Optional[list[Optional[torch.Tensor]]] = None,
        scaling: Optional[list[float]] = None,
    ) -> None:
        """Store per-slot tensors and resolve per-slot scaling factors.

        Args:
            module_name: Name of the packed module.
            rank: Shared LoRA rank.
            lora_alphas: Per-slot alpha (``None`` for an absent submodule).
            lora_a: Per-slot A matrices.
            lora_b: Per-slot B matrices.
            bias: Optional per-slot biases.
            scaling: Explicit per-slot scaling. When omitted it is derived
                as ``alpha / rank`` per slot, and ``None`` where the alpha
                is ``None``.
        """
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            bias=bias,
            scaling=scaling,  # type: ignore
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            # Absent submodules have no alpha. Propagate None instead of
            # failing on ``None / rank``.
            self.scaling = [  # type: ignore
                None if lora_alpha is None else lora_alpha / self.rank
                for lora_alpha in self.lora_alphas
            ]

    @classmethod
    def pack(
        cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
    ) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        Each non-None input is optimized in place first, so its scaling is
        folded into its ``lora_b``. Every packed slot therefore carries
        scaling 1 (or ``None`` for absent submodules). Rank and module name
        are taken from the first non-None LoRA.

        Raises:
            ValueError: if ``loras`` contains no non-None entry.
        """
        first_lora = next((lora for lora in loras if lora is not None), None)
        if first_lora is None:
            raise ValueError("pack() requires at least one non-None LoRA")
        for lora in loras:
            if lora is not None:
                lora.optimize()
        return cls(
            first_lora.module_name,
            first_lora.rank,
            [lora.lora_alpha if lora is not None else None for lora in loras],
            [lora.lora_a if lora is not None else None for lora in loras],
            [lora.lora_b if lora is not None else None for lora in loras],
            [lora.bias if lora is not None else None for lora in loras],
            scaling=[
                1 if lora is not None else None  # type: ignore
                for lora in loras
            ],
        )

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b.

        Each present ``lora_b`` slot whose scaling is not 1 is multiplied in
        place, and that slot's scaling is then set to 1. Slots without a
        ``lora_b`` are skipped.

        Returns:
            ``self``, to allow chaining.
        """
        for i, slot_b in enumerate(self.lora_b):
            if self.scaling[i] == 1 or slot_b is None:  # type: ignore
                continue
            self.lora_b[i] *= self.scaling[i]  # type: ignore
            self.scaling[i] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        """Not defined for packed weights (dims differ per slot)."""
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        """Not defined for packed weights (dims differ per slot)."""
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        """Always True: this object holds several packed sub-LoRAs."""
        return True