"examples/online_serving/run_cluster.sh" did not exist on "8228a79e9d80ff2b5217639e5c8b7d7825e9cbd7"
lora_manager.py 8.98 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"

import logging
from typing import Dict, List, Set, Tuple

import torch

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
    LoRABatchInfo,
    LoRAType,
    get_customized_names_from_hf_names,
    get_layer_id,
    get_stacked_name,
    get_weight_name,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import replace_submodule

logger = logging.getLogger(__name__)


class LoRAManager:
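    """Manages LoRA adapters for a base model.

    Responsibilities:
      - load the config and weights of every adapter listed in `lora_paths`,
      - monkey-patch the target layers of the base model with LoRA-aware layers,
      - keep a GPU memory pool holding the weights of the currently active adapters,
      - build the per-batch metadata (LoRABatchInfo) consumed by the backend kernels.

    A rough usage sketch (normally the model runner constructs and drives this):

        manager = LoRAManager(base_model, lora_paths, base_hf_config,
                              max_loras_per_batch, load_config, dtype)
        manager.prepare_lora_batch(forward_batch)  # once per forward pass
    """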
    def __init__(
        self,
        base_model: torch.nn.Module,
        lora_paths: Dict[str, str],
        base_hf_config: AutoConfig,
        max_loras_per_batch: int,
        load_config: LoadConfig,
        dtype: torch.dtype,
        lora_backend: str = "triton",
        tp_size: int = 1,
        tp_rank: int = 0,
    ):
        self.base_model: torch.nn.Module = base_model
        self.lora_paths: Dict[str, str] = lora_paths
        self.base_hf_config: AutoConfig = base_hf_config
        self.max_loras_per_batch: int = max_loras_per_batch
        self.load_config: LoadConfig = load_config
        self.dtype: torch.dtype = dtype
        self.device: torch.device = next(self.base_model.parameters()).device
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank

        # LoRA backend for running sgemm kernels
        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
        backend_type = get_backend_from_name(lora_backend)
        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

        self.init_loras()
        self.init_lora_memory_pool()

    def init_loras(self):
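        # Load the config and weights of every LoRA adapter and convert the
        # target layers of the base model into their LoRA versions.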
        # Config of each LoRA adapter
        self.configs: Dict[str, LoRAConfig] = {}

        # Target module names in huggingface lora configs.
        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
        self.hf_target_names: Set[str] = set()
        for name, path in self.lora_paths.items():
            self.configs[name] = LoRAConfig(path)
            self.hf_target_names.update(self.configs[name].target_modules)

        # Target lora weight names for lora_a and lora_b modules respectively.
        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
        self.lora_weight_names: Set[Tuple[str]] = set(
            [get_stacked_name(module) for module in self.hf_target_names]
        )

        # load all weights to cpu
        self.loras: Dict[str, LoRAAdapter] = {}
        for name in self.lora_paths.keys():
            lora_adapter = LoRAAdapter(
                name,
                self.configs[name],
                self.base_hf_config,
                self.load_config,
                self.lora_backend,
            )
            lora_adapter.initialize_weights()
            self.loras[name] = lora_adapter

        # misc lora configs
        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])

        # NOTE: self.lora_backend is a backend instance, not a string; compare the
        # name it was constructed with.
        if self.lora_backend.name == "flashinfer":
            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
            scaling = list(self.loras.values())[0].scaling
            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
            assert all(x.scaling == scaling for x in self.loras.values())

        # Convert original model layers to layers with LoRA
        self.convert_to_lora_layers()

    def init_lora_memory_pool(self):
        # Initialize memory pool
        self.memory_pool = LoRAMemoryPool(
            self.base_hf_config,
            self.max_loras_per_batch,
            self.max_lora_dim,
            self.dtype,
            self.tp_size,
            self.tp_rank,
            self.lora_modules,
        )

        # Initialize target lora modules in memory pool
        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)

    def prepare_lora_batch(self, forward_batch: ForwardBatch):
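        # Load the adapters used by this batch into the memory pool, build the
        # LoRABatchInfo consumed by the backend kernels, and point every LoRA
        # layer at the pooled buffers holding its adapter weights.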
        # load active loras into lora memory pool
        cur_uids = set(forward_batch.lora_paths)
        assert len(cur_uids) <= self.max_loras_per_batch
        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)

        # FIXME: Handle lora uid with None more safely
        if cur_uids == set([None]):
            return

        # set up batch info shared by all lora modules
        bs = forward_batch.batch_size
        seg_lens = (
            forward_batch.extend_seq_lens
            if forward_batch.forward_mode.is_extend()
            else torch.ones(bs, device=self.device)
        )
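        # seg_indptr is the prefix sum of seg_lens: the tokens of request i
        # occupy [seg_indptr[i], seg_indptr[i + 1]) in the flattened batch.
        # In non-extend (decode) mode every request contributes exactly one token.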
        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
        max_len = int(torch.max(seg_lens))
        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)

        lora_ranks = torch.empty(
            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
        )
        scalings = torch.empty(
            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
        )
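        # For each request, record the memory-pool slot of its adapter; the rank
        # and scaling are stored at that slot index so the kernels can look them
        # up through weight_indices.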
        for i, lora_path in enumerate(forward_batch.lora_paths):
            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
            lora = self.loras[lora_path]
            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
            scalings[weight_indices[i]] = lora.scaling

        batch_info = LoRABatchInfo(
            bs=bs,
            seg_lens=seg_lens,
            seg_indptr=seg_indptr,
            max_len=max_len,
            weight_indices=weight_indices,
            lora_ranks=lora_ranks,
            scalings=scalings,
        )
        self.lora_backend.set_batch_info(batch_info)

        # call set_lora_info for each lora module
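        # qkv_proj is a stacked module: it shares one fused lora_a buffer but
        # keeps separate lora_b buffers for q_proj and kv_proj.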
        for layer_id, modules in self.lora_modules.items():
            for module_name, module in modules:
                if "qkv_proj" in module_name:
                    module.set_lora_info(
                        self.memory_pool.get_tensor(
                            "qkv_proj", layer_id, LoRAType.LORA_A
                        ),
                        self.memory_pool.get_tensor(
                            "q_proj", layer_id, LoRAType.LORA_B
                        ),
                        self.memory_pool.get_tensor(
                            "kv_proj", layer_id, LoRAType.LORA_B
                        ),
                    )
                else:
                    weight_name = get_weight_name(
                        module_name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    module.set_lora_info(
                        self.memory_pool.get_tensor(
                            weight_name, layer_id, LoRAType.LORA_A
                        ),
                        self.memory_pool.get_tensor(
                            weight_name, layer_id, LoRAType.LORA_B
                        ),
                    )

    def set_lora_module(self, module_name, module):
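        # Swap the submodule at `module_name` in the base model with its
        # LoRA-wrapped counterpart and return the new module.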
        lora_module = get_lora_layer(module, self.lora_backend)
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def convert_to_lora_layers(self):
        # Target module names of customized layers defined in python/sglang/srt/layers
        # e.g., {"qkv_proj", "o_proj"}
        customized_target_names = get_customized_names_from_hf_names(
            self.hf_target_names, self.base_model
        )

        # Monkey patch to use the LoRA version layers
        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
            i: [] for i in range(self.base_hf_config.num_hidden_layers)
        }
        for module_name, module in self.base_model.named_modules():
            # The module should be converted if it is included in target_names
            if module_name.split(".")[-1] in customized_target_names:
                layer_id = get_layer_id(module_name)
                self.lora_modules[layer_id].append(
                    (module_name, self.set_lora_module(module_name, module))
                )