Unverified commit 4b74c3fc authored by Lifu Huang, committed by GitHub

[chore] Clean up redundant lora_weight_names concept to simplify code (#9131)

parent ce3ca9b0
@@ -32,8 +32,8 @@ from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
     get_layer_id,
-    get_normalized_lora_weight_names,
-    get_weight_name,
+    get_normalized_target_modules,
+    get_target_module_name,
 )
 from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -350,12 +350,20 @@ class LoRAManager:
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-                weight_name = get_weight_name(
-                    module_name, self.memory_pool.lora_weight_names
+                target_module = get_target_module_name(
+                    module_name, self.memory_pool.target_modules
                 )
                 module.set_lora_info(
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_A,
+                    ),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_B,
+                    ),
                 )
 
     def init_state(
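
Note on the hunk above: behavior is unchanged; the call sites now resolve the normalized target module once and pass it to get_tensor by keyword. A minimal, self-contained sketch of that flow, using a toy pool and a hypothetical module path rather than SGLang's real classes:

from enum import Enum, auto

class LoRAType(Enum):  # stand-in for sglang.srt.lora.utils.LoRAType
    LORA_A = auto()
    LORA_B = auto()

class ToyMemoryPool:
    # Toy stand-in for LoRAMemoryPool: A/B buffers keyed by normalized target module,
    # one slot per layer.
    def __init__(self, target_modules, num_layers):
        self.target_modules = set(target_modules)
        self.A_buffer = {m: [f"A[{m}][{i}]" for i in range(num_layers)] for m in target_modules}
        self.B_buffer = {m: [f"B[{m}][{i}]" for i in range(num_layers)] for m in target_modules}

    def get_tensor(self, target_module, layer_id, lora_type):
        buf = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer
        return buf[target_module][layer_id]

pool = ToyMemoryPool({"qkv_proj", "o_proj"}, num_layers=2)
module_name = "model.layers.1.self_attn.qkv_proj"  # hypothetical full module path
# Same substring rule that get_target_module_name uses to map a full path to its target module.
target = next(t for t in pool.target_modules if t in module_name)
print(pool.get_tensor(target_module=target, layer_id=1, lora_type=LoRAType.LORA_A))  # A[qkv_proj][1]
print(pool.get_tensor(target_module=target, layer_id=1, lora_type=LoRAType.LORA_B))  # B[qkv_proj][1]
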
@@ -380,7 +388,6 @@ class LoRAManager:
             max_lora_rank=max_lora_rank,
             target_modules=target_modules,
         )
-        self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
         self.update_lora_info()
@@ -426,6 +433,7 @@ class LoRAManager:
                        "enable all support modules types. "
                    )
                self.target_modules.update(config.target_modules)
+        self.target_modules = get_normalized_target_modules(self.target_modules)
 
         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank
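
The hunk above normalizes the target set once, right after it is collected from the adapter configs, which is what lets the separate lora_weight_names set (and init_lora_weight_names below) disappear. A toy before/after illustration; the q/k/v to qkv_proj and gate/up to gate_up_proj folding is an assumption about what the normalization does for stacked projections, not something quoted from this diff:

# Assumed mapping, for illustration only; the real table lives in sglang.srt.lora.utils.
ASSUMED_STACKED_MAPPING = {
    "q_proj": "qkv_proj", "k_proj": "qkv_proj", "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj", "up_proj": "gate_up_proj",
}

def normalize(target_modules):
    # Mirrors the shape of get_normalized_target_modules: map each name, keep unknowns as-is.
    return {ASSUMED_STACKED_MAPPING.get(name, name) for name in target_modules}

collected = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
print(sorted(normalize(collected)))
# ['down_proj', 'gate_up_proj', 'o_proj', 'qkv_proj']  -- one entry per stacked buffer
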
@@ -435,15 +443,6 @@ class LoRAManager:
                 default=0,
             )
 
-    def init_lora_weight_names(self):
-        """
-        Add new LoRA weight names if needed based on the current `self.configs`.
-        """
-        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
-            self.target_modules
-        )
-
     def load_lora_weights(self, lora_ref: LoRARef):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
@@ -467,7 +466,7 @@ class LoRAManager:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
             max_lora_rank=self.max_lora_rank,
-            lora_weight_names=self.lora_weight_names,
+            target_modules=self.target_modules,
             base_model=self.base_model,
         )
@@ -494,7 +493,7 @@ class LoRAManager:
                 continue
 
             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in self.lora_weight_names:
+            if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
...
@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
-    get_normalized_lora_weight_names,
+    get_normalized_target_modules,
     get_stacked_multiply,
-    get_weight_name,
+    get_target_module_name,
 )
 
 logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names: Set[str],
+        target_modules: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,7 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-        self.lora_weight_names: Set[str] = lora_weight_names
+        self.target_modules: Set[str] = target_modules
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -95,8 +95,8 @@ class LoRAMemoryPool:
            """
            if config.r > self.max_lora_rank:
                return False
-           weights = get_normalized_lora_weight_names(config.target_modules)
-           return weights.issubset(self.lora_weight_names)
+           target_module_names = get_normalized_target_modules(config.target_modules)
+           return target_module_names.issubset(self.target_modules)
 
        if isinstance(config, LoRAConfig):
            return _can_support(config)
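
With both sides already normalized, the compatibility check above reduces to a rank bound plus a set containment test. A minimal sketch of that logic (standalone, with placeholder values):

# Sketch of _can_support: an adapter fits the pool only if its rank is within budget
# and every module it targets has a preallocated buffer in the pool.
def can_support(adapter_rank, adapter_modules, max_lora_rank, pool_modules):
    if adapter_rank > max_lora_rank:
        return False
    return set(adapter_modules).issubset(pool_modules)

pool_modules = {"qkv_proj", "o_proj", "gate_up_proj", "down_proj"}
print(can_support(16, {"qkv_proj", "o_proj"}, 64, pool_modules))   # True
print(can_support(128, {"qkv_proj"}, 64, pool_modules))            # False: rank exceeds budget
print(can_support(16, {"embed_tokens"}, 64, pool_modules))         # False: no buffer for that module
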
@@ -139,10 +139,10 @@ class LoRAMemoryPool:
        def init_buffer(
            buffer: Dict[str, List[torch.Tensor]],
-           lora_weight_names: Set[str],
+           target_modules: Set[str],
            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
        ):
-           for module_name in lora_weight_names:
+           for module_name in target_modules:
                lora_shape = get_lora_shape_fn(
                    module_name, base_model, self.max_lora_rank
                )
@@ -157,13 +157,13 @@ class LoRAMemoryPool:
        init_buffer(
            self.A_buffer,
-           self.lora_weight_names,
+           self.target_modules,
            self.get_lora_A_shape,
        )
 
        init_buffer(
            self.B_buffer,
-           self.lora_weight_names,
+           self.target_modules,
            self.get_lora_B_shape,
        )
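
init_buffer above produces, for each of A_buffer and B_buffer, a dict keyed by normalized target module whose values are per-layer tensor lists. A rough sketch of that layout with made-up shapes; the real get_lora_A_shape/get_lora_B_shape depend on the base model and the TP split, so the sizes below are placeholders:

import torch

num_layers, max_loras, max_rank, hidden = 2, 4, 16, 64  # placeholder sizes

def init_buffer(buffer, target_modules, get_shape):
    # One entry per target module; each entry holds one tensor per decoder layer.
    for module_name in target_modules:
        buffer[module_name] = [torch.empty(get_shape(module_name)) for _ in range(num_layers)]

A_buffer, B_buffer = {}, {}
init_buffer(A_buffer, ["qkv_proj", "o_proj"], lambda m: (max_loras, max_rank, hidden))
init_buffer(B_buffer, ["qkv_proj", "o_proj"], lambda m: (max_loras, hidden, max_rank))
print({m: [tuple(t.shape) for t in layers] for m, layers in A_buffer.items()})
# {'qkv_proj': [(4, 16, 64), (4, 16, 64)], 'o_proj': [(4, 16, 64), (4, 16, 64)]}
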
@@ -242,32 +242,34 @@ class LoRAMemoryPool:
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
            temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
-               weight_name: None for weight_name in self.A_buffer
+               target_module: None for target_module in self.A_buffer
            }
            temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
-               weight_name: None for weight_name in self.B_buffer
+               target_module: None for target_module in self.B_buffer
            }
 
            for name, weights in layer_weights.items():
-               lora_weight_name = get_weight_name(name, self.lora_weight_names)
+               target_module = get_target_module_name(name, self.target_modules)
                if "lora_A" in name:
-                   temp_A_buffer[lora_weight_name] = weights
+                   temp_A_buffer[target_module] = weights
                else:
-                   temp_B_buffer[lora_weight_name] = weights
+                   temp_B_buffer[target_module] = weights
 
            if self.tp_size > 1:
                cur_layer_modules = lora_modules[layer_id]
                for module_name, module in cur_layer_modules.items():
-                   weight_name = get_weight_name(module_name, self.lora_weight_names)
+                   target_module = get_target_module_name(
+                       module_name, self.target_modules
+                   )
 
-                   if temp_A_buffer[weight_name] is None:
+                   if temp_A_buffer[target_module] is None:
                        # Skip weight slicing if the weight is not present in the adapter
                        continue
-                   temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                       temp_A_buffer[weight_name], self.tp_rank
+                   temp_A_buffer[target_module] = module.slice_lora_a_weights(
+                       temp_A_buffer[target_module], self.tp_rank
                    )
-                   temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                       temp_B_buffer[weight_name], self.tp_rank
+                   temp_B_buffer[target_module] = module.slice_lora_b_weights(
+                       temp_B_buffer[target_module], self.tp_rank
                    )
 
            for name, weights in temp_A_buffer.items():
@@ -282,12 +284,12 @@ class LoRAMemoryPool:
                load_lora_weight_tensor(buffer_view, weights)
 
    def get_tensor(
-       self, weight_name: str, layer_id: int, lora_type: LoRAType
+       self, target_module: str, layer_id: int, lora_type: LoRAType
    ) -> torch.Tensor:
        if lora_type == LoRAType.LORA_A:
-           return self.A_buffer[weight_name][layer_id]
+           return self.A_buffer[target_module][layer_id]
 
-       return self.B_buffer[weight_name][layer_id]
+       return self.B_buffer[target_module][layer_id]
 
    def get_buffer_id(self, lora_uid: str):
        return self.uid_to_buffer_id[lora_uid]
@@ -84,7 +84,7 @@ def get_hidden_dim(
    raise NotImplementedError()
 
 
-def get_normalized_lora_weight_names(
+def get_normalized_target_modules(
    target_modules: Iterable[str],
 ) -> set[str]:
    """
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(
    result = set()
    for name in target_modules:
-       weight_name = params_mapping.get(name, name)
-       result.add(weight_name)
+       normalized_name = params_mapping.get(name, name)
+       result.add(normalized_name)
    return result
@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
    return stacked_rank[module_name] if module_name in stacked_rank else 1
 
 
-def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]]
-) -> Optional[str]:
+def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
    """
-    Get the weight name in lora_weight_names that can match target_name.
-    If there is a weight name in lora_weight_names that can match target_name, return this name
+    Get the target module name in target_modules that can match full_module_name.
+    If there is a target module name in target_modules that can match full_module_name, return this name
    Else raise ValueError.
    """
-    for weight_name in lora_weight_names:
-        if weight_name in target_name:
-            return weight_name
+    for target_module in target_modules:
+        if target_module in full_module_name:
+            return target_module
    raise ValueError(
-        f"Cannot find weight name for {target_name} in {lora_weight_names}"
+        f"Cannot find target module name for {full_module_name} in {target_modules}"
    )
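
Usage of the renamed helper: it returns a target module from the set whose name is a substring of the full module path, and raises a ValueError when none matches. The module paths below are illustrative; the import requires a build that includes this change:

from sglang.srt.lora.utils import get_target_module_name

targets = {"qkv_proj", "o_proj", "gate_up_proj", "down_proj"}

print(get_target_module_name("model.layers.0.self_attn.qkv_proj", targets))  # qkv_proj
print(get_target_module_name("model.layers.7.mlp.down_proj", targets))       # down_proj

try:
    get_target_module_name("model.embed_tokens", targets)  # no match -> ValueError
except ValueError as err:
    print(err)
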
...
@@ -2874,6 +2874,8 @@ SUPPORTED_LORA_TARGET_MODULES = [
    "gate_proj",
    "up_proj",
    "down_proj",
+   "qkv_proj",
+   "gate_up_proj",
 ]
 
 LORA_TARGET_ALL_MODULES = "all"
...