"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"


import re

import torch

from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.utils import is_hip, replace_submodule

# ROCm: flashinfer available later
if not is_hip():
    from flashinfer import SegmentGEMMWrapper


def get_stacked_name(name):
    """Map an original projection name to its stacked ``(A name, B name)`` pair.

    The A weights of q/k/v share the stacked ``qkv_proj`` buffer, while their
    B weights are kept as ``q_proj`` and a shared ``kv_proj``. gate/up
    projections use ``gate_up_proj`` for both A and B. Any other name maps to
    ``(name, name)``.
    """
    # origin name -> (name for A, name for B)
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))


def get_layer_id(name):
    """Return the integer N from the first ``layers.N.`` in *name*, else None."""
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))


class LoRAManager:
    """Serve multiple LoRA adapters on top of one base model.

    On construction it reads every adapter's config and weights, replaces the
    base model's target submodules with LoRA-aware layers, and preallocates
    per-layer CUDA buffers holding up to ``max_loras_per_batch`` adapters.
    ``prepare_lora_batch`` then makes the adapters a batch needs resident in
    those buffers and hands the buffers to the patched modules.
    """

    def __init__(
        self,
        base_model,
        lora_paths,
        base_hf_config,
        max_loras_per_batch,
        load_config,
        dtype,
    ):
        """Load all adapters and set up the LoRA buffers.

        Args:
            base_model: model to patch; must provide ``named_modules``,
                ``get_module_name`` and ``get_hidden_dim``.
            lora_paths: mapping of adapter name -> adapter path.
            base_hf_config: base model config (``num_hidden_layers`` is read).
            max_loras_per_batch: number of adapter slots in the buffer pool.
            load_config: forwarded to ``LoRAAdapter``.
            dtype: dtype of the preallocated weight buffers.
        """
        self.base_model = base_model
        self.lora_paths = lora_paths
        self.base_hf_config = base_hf_config
        self.max_loras_per_batch = max_loras_per_batch
        self.load_config = load_config
        self.dtype = dtype

        # 1 MiB scratch space for the segment GEMM kernel.
        # NOTE(review): on ROCm (is_hip()) SegmentGEMMWrapper is not imported
        # above, so this line raises NameError there.
        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

        self.init_loras()
        self.init_lora_memory_pool()
        self.init_lora_batch()

    def match_target_modules(self, module_name):
        """Return True if the last dotted component of *module_name* is a target."""
        return module_name.split(".")[-1] in self.target_modules

    def get_target_modules(self):
        """Return ``(name, module)`` pairs of base-model modules to patch."""
        return [
            (module_name, module)
            for module_name, module in self.base_model.named_modules()
            if self.match_target_modules(module_name)
        ]

    def set_lora_module(self, module_name, module):
        """Replace *module_name* in the base model with its LoRA wrapper and return it."""
        lora_module = get_lora_layer(
            module, self.segment_gemm, self.max_lora_dim, self.scaling
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_loras(self):
        """Read adapter configs and weights, then patch the target modules."""
        # get configs and target modules
        self.configs = {}
        self.origin_target_modules = set()
        for name, path in self.lora_paths.items():
            self.configs[name] = LoRAConfig(path)
            self.origin_target_modules.update(self.configs[name].target_modules)
        self.target_modules = {
            self.base_model.get_module_name(module)
            for module in self.origin_target_modules
        }
        self.target_weights = {
            get_stacked_name(module) for module in self.origin_target_modules
        }

        # load all weights to cpu
        self.loras = []
        self.lora_id = {}
        for name in self.lora_paths.keys():
            self.lora_id[name] = len(self.loras)
            self.loras.append(
                LoRAAdapter(
                    name, self.configs[name], self.base_hf_config, self.load_config
                )
            )
            self.loras[-1].initialize_weights()

        # misc lora configs
        self.max_lora_dim = max(x.hf_config["r"] for x in self.configs.values())
        self.scaling = self.loras[0].scaling
        # FIXME remove the restrictions
        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
        assert all(x.scaling == self.scaling for x in self.loras)

        # monkey patch to use the LoRA version
        self.lora_modules = []
        for module_name, module in self.get_target_modules():
            self.lora_modules.append(
                (module_name, self.set_lora_module(module_name, module))
            )

    def init_lora_memory_pool(self):
        """Preallocate per-layer A/B buffers with ``max_loras_per_batch`` slots.

        A buffers have shape (slots, max_lora_dim * c, hidden_in) and B
        buffers (slots, hidden_out * c, max_lora_dim), where c is the stacking
        multiplier reported by the adapter. Contents are uninitialized
        (``torch.empty``) until ``load_lora`` writes a slot.
        """
        # preallocate lora memory pool
        self.A_buffer = {}
        self.B_buffer = {}
        num_layer = self.base_hf_config.num_hidden_layers
        for module_A, module_B in self.target_weights:
            # init A tensor, column_major=True
            hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
            c = self.loras[-1].get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            hidden_dim_A,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]
            # init B tensor, column_major=True
            _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
            c = self.loras[-1].get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            hidden_dim_B * c,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]

    def init_lora_batch(self):
        """Reset the bookkeeping of which adapter lives in which buffer slot."""
        self.active_uids = set()  # set of active loras
        self.buffer_id = {}  # lora uid -> idx in memory pool

    def get_weight_name(self, name, idx):
        """Return the first target weight name (A if idx == 0, B if idx == 1) found in *name*.

        Returns None when no target name occurs in *name*.
        NOTE(review): matching is by substring over a set, so overlapping names
        (e.g. "kv_proj" is a substring of "qkv_proj") resolve in set order.
        """
        for target_weight_name in self.target_weights:
            if target_weight_name[idx] in name:
                return target_weight_name[idx]

    def load_lora(self, uid, buffer_id):
        """Copy adapter *uid*'s weights into slot *buffer_id* of every layer.

        ``uid=None`` means "no adapter": the slot's A and B weights are zeroed,
        so the LoRA contribution computed from that slot is exactly zero.
        """
        num_layer = self.base_hf_config.num_hidden_layers
        if uid is None:
            # zero_() instead of `*= 0`: the pool comes from torch.empty, and
            # uninitialized NaN/Inf entries stay NaN when multiplied by zero.
            # B is cleared as well, because 0 @ NaN is still NaN.
            for i in range(num_layer):
                for buffer in self.A_buffer.values():
                    buffer[i][buffer_id].zero_()
                for buffer in self.B_buffer.values():
                    buffer[i][buffer_id].zero_()
            return

        adapter = self.loras[self.lora_id[uid]]
        for i in range(num_layer):
            layer_weights = adapter.layers[i].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = self.get_weight_name(name, 0)
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

    def prepare_lora_batch(self, input_metadata: InputMetadata):
        """Make the batch's adapters resident and configure the LoRA modules.

        An adapter that is not yet resident goes into a free slot if one
        exists. Otherwise it takes the slot of a resident adapter that this
        batch does not use. Each patched module then receives its layer's
        buffers, the per-request segment lengths and the per-request slot
        indices.
        """
        # load active loras into lora memory pool
        cur_uids = set(input_metadata.lora_paths)
        assert len(cur_uids) <= self.max_loras_per_batch

        occupied_slots = set(self.buffer_id.values())
        free_slots = [
            slot for slot in range(self.max_loras_per_batch) if slot not in occupied_slots
        ]
        # Resident adapters not needed by this batch; safe to evict. Given the
        # assert above, free + evictable slots always cover the new adapters.
        evictable_uids = [uid for uid in self.active_uids if uid not in cur_uids]
        for uid in cur_uids:
            if uid in self.active_uids:
                continue
            if free_slots:
                slot = free_slots.pop(0)
            else:
                # Reuse the victim's own slot so that no adapter still in use
                # has its weights overwritten.
                victim = evictable_uids.pop(0)
                self.active_uids.remove(victim)
                slot = self.buffer_id.pop(victim)
            self.load_lora(uid, slot)
            self.active_uids.add(uid)
            self.buffer_id[uid] = slot

        if cur_uids == {None}:
            return

        # setup lora in forward modules
        bs = input_metadata.batch_size
        # NOTE(review): in non-extend mode this is a float32 CPU tensor;
        # confirm set_lora_info / the segment GEMM accept or convert it.
        seg_lens = (
            input_metadata.extend_seq_lens
            if input_metadata.forward_mode.is_extend()
            else torch.ones(bs)
        )
        # Build slot indices on the host and transfer once, rather than issuing
        # one scalar write per request into a CUDA tensor. Assumes one entry in
        # lora_paths per request (len == batch_size) — TODO confirm.
        weight_indices = torch.tensor(
            [self.buffer_id[lora_path] for lora_path in input_metadata.lora_paths],
            dtype=torch.int64,
            device="cuda",
        )
        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)

            if "qkv_proj" not in module_name:
                weight_name = self.get_weight_name(module_name, 0)
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )