# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"

# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py

import logging
import re
from typing import Dict, List

import torch
from torch import nn

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader

logger = logging.getLogger(__name__)


class LoRALayer(nn.Module):
    def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
        super().__init__()
        self.config: LoRAConfig = config
        self.base_hf_config: AutoConfig = base_hf_config

        # LoRA weights kept on CPU; they are loaded from the checkpoint.
        self.weights: Dict[str, torch.Tensor] = {}


class LoRAAdapter(nn.Module):
    def __init__(
        self,
        uid: str,
        config: LoRAConfig,
        base_hf_config: AutoConfig,
        load_config: LoadConfig,
        lora_backend: BaseLoRABackend,
    ):
        super().__init__()
        self.uid: str = uid
        self.config: LoRAConfig = config
        assert self.config.hf_config["peft_type"].lower() == "lora"
        self.base_hf_config: AutoConfig = base_hf_config
        self.load_config: LoadConfig = load_config
        self.lora_backend: BaseLoRABackend = lora_backend
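        # Standard LoRA scaling factor (lora_alpha / r) applied to the low-rank update B @ A.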
        self.scaling: float = self.config.lora_alpha / self.config.r

        self.layers: List[LoRALayer] = nn.ModuleList(
            [
                LoRALayer(config, base_hf_config)
                for _ in range(base_hf_config.num_hidden_layers)
            ]
        )

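        # Adapter-level weights whose names are not tied to a specific decoder
        # layer; populated by initialize_weights().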
        self.weights: Dict[str, torch.Tensor] = {}

    # Initialize the LoRA weights on CPU.
    def initialize_weights(self):
        model_path = self.config.path
        loader = DefaultModelLoader(self.load_config)
        revision = getattr(self.config.hf_config, "revision", None)
        for name, loaded_weight in loader._get_weights_iterator(
            DefaultModelLoader.Source(
                model_path, revision=revision, fall_back_to_pt=True
            )
        ):
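            # Route each checkpoint tensor by name: weights under "layers.<i>."
            # go to that layer's dict, everything else stays at the adapter level.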
            match = re.search(r"layers\.(\d+)\.", name)
            if match is not None:
                layer_id = int(match.group(1))
                self.layers[layer_id].weights[name] = loaded_weight.cpu()
            else:
                self.weights[name] = loaded_weight.cpu()

        # Stack q/k/v into qkv_proj / kv_proj and gate/up into gate_up_proj.
        for i in range(self.base_hf_config.num_hidden_layers):
            layer = self.layers[i]
            weight_names = list(layer.weights.keys())
            self.normalize_qkv_proj(weight_names, layer.weights)
            self.normalize_gate_up_proj(weight_names, layer.weights)

    def normalize_qkv_proj(
        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
    ):
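        # Normalize per-projection q/k/v LoRA weights into SGLang's fused layout:
        # lora_A weights are merged into a single "qkv_proj" tensor and the k/v
        # lora_B weights into a stacked "kv_proj" tensor.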
        # Collect the target q/k/v modules. This is necessary because an adapter
        # may not attach LoRA weights to k_proj.
        target_module = set()
        for weight_name in weight_names:
            if "k_proj" in weight_name:
                target_module.add("k_proj")
            if "q_proj" in weight_name:
                target_module.add("q_proj")
            if "v_proj" in weight_name:
                target_module.add("v_proj")
            if "qkv_proj" in weight_name:
                target_module.add("qkv_proj")
        if len(target_module) == 0:
            return

        for weight_name in weight_names:
            # We assume every LoRA adapter contains LoRA modules for q_proj.
            if "q_proj" in weight_name:
                q_name = weight_name
                k_name = weight_name.replace("q_proj", "k_proj")
                v_name = weight_name.replace("q_proj", "v_proj")
                kv_name = weight_name.replace("q_proj", "kv_proj")
                qkv_name = weight_name.replace("q_proj", "qkv_proj")

                # If k_proj has no LoRA weights, initialize them to zero.
                k_proj_weight = (
                    weights[k_name]
                    if "k_proj" in target_module
                    else torch.zeros_like(weights[v_name])
                )
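                # lora_A matrices share the same input width, so the q/k/v A-weights
                # are concatenated along dim 0; lora_B output widths can differ, so
                # the k/v B-weights are stacked into a separate kv_proj tensor.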
                if "lora_A" in weight_name:
                    weights[qkv_name] = torch.cat(
                        (
                            weights[q_name],
                            k_proj_weight,
                            weights[v_name],
                        ),
                        0,
                    )
                    weights.pop(q_name)
                    if "k_proj" in target_module:
                        weights.pop(k_name)
                    weights.pop(v_name)
                else:
                    weights[kv_name] = torch.stack(
                        [
                            k_proj_weight,
                            weights[v_name],
                        ],
                        dim=0,
                    )
                    if "k_proj" in target_module:
                        weights.pop(k_name)
                    weights.pop(v_name)
            elif "qkv_proj" in weight_name:
                # If qkv_proj is already stacked, we normalize it following the SGL convention.
                qkv_name = weight_name
                q_name = weight_name.replace("qkv_proj", "q_proj")
                k_name = weight_name.replace("qkv_proj", "k_proj")
                v_name = weight_name.replace("qkv_proj", "v_proj")
                kv_name = weight_name.replace("qkv_proj", "kv_proj")
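                # The shared lora_A is replicated three times to match the
                # concatenated q/k/v layout; the stacked lora_B is split back into
                # q_proj and kv_proj slices using the attention head geometry.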
                if "lora_A" in weight_name:
                    weights[qkv_name] = weights[qkv_name].repeat(3, 1)
                else:
                    head_size = (
                        self.base_hf_config.hidden_size
                        // self.base_hf_config.num_attention_heads
                    )
                    weights[q_name], weights[kv_name] = torch.split(
                        weights[qkv_name],
                        [
                            head_size * self.base_hf_config.num_attention_heads,
                            head_size * self.base_hf_config.num_key_value_heads * 2,
                        ],
                        dim=0,
                    )

    def normalize_gate_up_proj(
        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
    ):
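        # Merge gate_proj / up_proj into gate_up_proj, mirroring the qkv handling:
        # lora_A weights are concatenated along dim 0, lora_B weights are stacked.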
        for weight_name in weight_names:
            if "gate_proj" in weight_name:
                up_name = weight_name.replace("gate_proj", "up_proj")
                gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                if up_name not in weights:
                    logger.warning(
                        f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
                        f"Initializing up projection to zero."
                    )
                    weights[up_name] = torch.zeros_like(weights[weight_name])
                    # FIXME: Add gate-only support for flashinfer in future implementations
                    assert self.lora_backend.name == "triton", (
                        f"LoRA weight initialization currently only supported for 'triton' backend. "
                        f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                        f"or consider implementing custom initialization logic for other backends."
                    )
                if "lora_A" in weight_name:
                    weights[gate_up_name] = torch.cat(
                        (weights[weight_name], weights[up_name]), 0
                    )
                else:
                    weights[gate_up_name] = torch.stack(
                        [weights[weight_name], weights[up_name]], dim=0
                    )
                weights.pop(weight_name)
                if up_name in weights:
                    weights.pop(up_name)
            elif "gate_up_proj" in weight_name:
                # If gate_up_proj is already stacked, we normalize it following the SGL convention
                gate_up_name = weight_name
                if "lora_A" in weight_name:
                    weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
                else:
                    output_dim = weights[gate_up_name].shape[0] // 2
                    weights[gate_up_name] = torch.stack(
                        [
                            weights[gate_up_name][:output_dim, :],
                            weights[gate_up_name][output_dim:, :],
                        ],
                        dim=0,
                    )
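

# Usage sketch (illustrative only): the constructor arguments shown here are
# assumptions rather than the exact SGLang call site, which lives in the LoRA
# manager.
#
#   adapter = LoRAAdapter(
#       uid="lora-0",
#       config=lora_config,          # LoRAConfig pointing at a PEFT checkpoint
#       base_hf_config=base_hf_config,
#       load_config=load_config,
#       lora_backend=lora_backend,   # e.g. the Triton backend
#   )
#   adapter.initialize_weights()     # load to CPU, then normalize qkv/gate_up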