lora_weights.py 5.65 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Sequence as GenericSequence
from typing import Optional
6
7

import torch
8
import torch.types
9

10
from vllm.lora.peft_helper import PEFTHelper
11
from vllm.utils import is_pin_memory_available
12
13
14
15
16
17
18
19
20
21
22
23


class LoRALayerWeights:
    """Low-rank adapter weights attached to a single (unpacked) layer.

    The adapter is described by two low-rank matrices, ``lora_a`` and
    ``lora_b``, plus an optional ``embeddings_tensor``. ``scaling`` is the
    multiplier associated with the adapter; when it is not given explicitly
    it defaults to ``lora_alpha / rank``.
    """

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: torch.Tensor | None = None,
        scaling: float | None = None,
    ) -> None:
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.embeddings_tensor = embeddings_tensor
        # An explicit scaling factor wins; otherwise use the alpha / rank ratio.
        self.scaling = lora_alpha / rank if scaling is None else scaling

    def optimize(self) -> "LoRALayerWeights":
        """Fold ``scaling`` into ``lora_b`` and reset ``scaling`` to 1.

        ``lora_b`` is updated with ``*=``, which mutates a tensor in place,
        so other references to the same tensor observe the change. Does
        nothing when ``scaling`` is already 1. Returns ``self``.
        """
        if self.scaling != 1:
            self.lora_b *= self.scaling
            self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
        """Size of the second axis of ``lora_a``."""
        # create_dummy_lora_weights lays lora_a out as (rank, input_dim).
        return self.lora_a.shape[1]

    @property
    def output_dim(self) -> int:
        """Size of the first axis of ``lora_b``."""
        # create_dummy_lora_weights lays lora_b out as (output_dim, rank).
        return self.lora_b.shape[0]

    @property
    def is_packed(self) -> bool:
        """Always False: these weights belong to a single, unpacked layer."""
        return False

    @property
    def extra_vocab_size(self) -> int:
        """Number of rows in ``embeddings_tensor``, or 0 when there is none."""
        embeddings = self.embeddings_tensor
        if embeddings is None:
            return 0
        return embeddings.shape[0]

    @classmethod
    def from_config(
        cls,
        module_name: str,
        peft_helper: PEFTHelper,
        embeddings_tensor: torch.Tensor | None = None,
    ) -> "LoRALayerWeights":
        """Create weights from a PEFT config without attaching any tensors.

        Rank, alpha and scaling factor are read from ``peft_helper``;
        ``lora_a`` and ``lora_b`` are both left as ``None``.
        """
        rank = peft_helper.r
        alpha = peft_helper.lora_alpha
        scale = peft_helper.vllm_lora_scaling_factor
        # NOTE(review): lora_a / lora_b are annotated as Tensor but None is
        # passed for config-based construction.
        return cls(module_name, rank, alpha, None, None, embeddings_tensor, scale)

    @classmethod
    def create_dummy_lora_weights(
        cls,
        module_name: str,
        input_dim: int,
        output_dim: int,
        rank: int,
        dtype: torch.dtype,
        device: torch.types.Device,
        embeddings_tensor_dim: int | None = None,
    ) -> "LoRALayerWeights":
        """Allocate placeholder weights with the requested shapes.

        ``lora_a`` is a zero tensor of shape ``(rank, input_dim)`` and
        ``lora_b`` a zero tensor of shape ``(output_dim, rank)``. When
        ``embeddings_tensor_dim`` is truthy, a random embeddings tensor of
        shape ``(10, embeddings_tensor_dim)`` is also created. Memory is
        pinned only when ``device`` is ``"cpu"`` and pinning is available.
        The result uses ``lora_alpha=1``, so its scaling is ``1 / rank``.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        alloc = dict(dtype=dtype, device=device, pin_memory=pin_memory)

        lora_a = torch.zeros((rank, input_dim), **alloc)
        lora_b = torch.zeros((output_dim, rank), **alloc)

        embeddings_tensor = None
        if embeddings_tensor_dim:
            embeddings_tensor = torch.rand((10, embeddings_tensor_dim), **alloc)

        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            embeddings_tensor=embeddings_tensor,
        )


class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA weights for packed layers (e.g. ``qkv_proj``).

    ``lora_alphas``, ``lora_a``, ``lora_b`` and ``scaling`` each hold one
    entry per packed sub-module. A ``None`` entry means that sub-module has
    no LoRA.
    """

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: list[int | None],
        lora_a: list[torch.Tensor | None],
        lora_b: list[torch.Tensor | None],
        scaling: list[float | None] | None = None,
    ) -> None:
        super().__init__(
            module_name=module_name,
            rank=rank,
            # Placeholder: per-slot alphas are stored in self.lora_alphas.
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            scaling=scaling,  # type: ignore
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            # Derive per-slot scaling as alpha / rank. Slots without a LoRA
            # (alpha is None) get a None scaling instead of raising TypeError.
            self.scaling = [  # type: ignore
                None if lora_alpha is None else lora_alpha / self.rank
                for lora_alpha in self.lora_alphas
            ]

    @classmethod
    def pack(
        cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
    ) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        Every non-None input is optimized in place (its scaling is folded
        into its ``lora_b``), so each packed slot carries a scaling of 1.
        Module name and rank are taken from the first non-None LoRA.

        Raises:
            ValueError: if every entry of ``loras`` is None.
        """
        first_lora = next((lora for lora in loras if lora is not None), None)
        if first_lora is None:
            # A bare next() would leak StopIteration here; fail explicitly.
            raise ValueError("Cannot pack LoRAs: every entry is None.")
        for lora in loras:
            if lora is not None:
                lora.optimize()
        return cls(
            first_lora.module_name,
            first_lora.rank,
            [lora.lora_alpha if lora is not None else None for lora in loras],
            [lora.lora_a if lora is not None else None for lora in loras],
            [lora.lora_b if lora is not None else None for lora in loras],
            scaling=[
                1 if lora is not None else None  # type: ignore
                for lora in loras
            ],
        )

    def optimize(self) -> "PackedLoRALayerWeights":
        """Fold each slot's scaling into its ``lora_b`` and reset it to 1.

        Slots whose scaling is already 1, or that have no ``lora_b``, are
        left untouched. Tensors are updated with ``*=`` (in place).
        Returns ``self``.
        """
        for i, slot_b in enumerate(self.lora_b):
            if self.scaling[i] == 1 or slot_b is None:  # type: ignore
                continue
            self.lora_b[i] *= self.scaling[i]  # type: ignore
            self.scaling[i] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        raise NotImplementedError(
            "input_dim is not defined for packed LoRA weights"
        )

    @property
    def output_dim(self) -> int:
        raise NotImplementedError(
            "output_dim is not defined for packed LoRA weights"
        )

    @property
    def is_packed(self) -> bool:
        return True