Unverified Commit 4b74c3fc authored by Lifu Huang, committed by GitHub

[chore] Clean up redundant lora_weight_names concept to simplify code (#9131)

parent ce3ca9b0
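
Before the diff, a brief sketch of the two renamed helpers. This is a minimal reimplementation mirroring the function bodies shown in the diff below; the contents of `params_mapping` are an illustrative assumption (per-projection names collapsing into stacked modules such as `qkv_proj` and `gate_up_proj`), not a copy of the real table in `sglang.srt.lora.utils`.

```python
# Illustrative sketch of the renamed helpers; params_mapping here is assumed.
from typing import Iterable, Set

# Assumed mapping from per-projection HF module names to their stacked
# counterparts used by the runtime (the real table lives in lora/utils.py).
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}


def get_normalized_target_modules(target_modules: Iterable[str]) -> set[str]:
    """Normalize adapter target modules to the runtime's stacked module names."""
    return {params_mapping.get(name, name) for name in target_modules}


def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
    """Return the target module name that matches a full module path by substring."""
    for target_module in target_modules:
        if target_module in full_module_name:
            return target_module
    raise ValueError(
        f"Cannot find target module name for {full_module_name} in {target_modules}"
    )


# Example usage:
#   get_normalized_target_modules(["q_proj", "k_proj", "o_proj"])
#       -> {"qkv_proj", "o_proj"}
#   get_target_module_name("model.layers.0.self_attn.qkv_proj", {"qkv_proj"})
#       -> "qkv_proj"
```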
@@ -32,8 +32,8 @@ from sglang.srt.lora.utils import (
LoRABatchInfo,
LoRAType,
get_layer_id,
get_normalized_lora_weight_names,
get_weight_name,
get_normalized_target_modules,
get_target_module_name,
)
from sglang.srt.managers.io_struct import LoRAUpdateResult
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -350,12 +350,20 @@ class LoRAManager:
"""
for layer_id, layer_modules in enumerate(self.lora_modules):
for module_name, module in layer_modules.items():
weight_name = get_weight_name(
module_name, self.memory_pool.lora_weight_names
target_module = get_target_module_name(
module_name, self.memory_pool.target_modules
)
module.set_lora_info(
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
self.memory_pool.get_tensor(
target_module=target_module,
layer_id=layer_id,
lora_type=LoRAType.LORA_A,
),
self.memory_pool.get_tensor(
target_module=target_module,
layer_id=layer_id,
lora_type=LoRAType.LORA_B,
),
)
def init_state(
@@ -380,7 +388,6 @@ class LoRAManager:
max_lora_rank=max_lora_rank,
target_modules=target_modules,
)
self.init_lora_weight_names()
self.init_lora_modules()
self.init_memory_pool()
self.update_lora_info()
@@ -426,6 +433,7 @@ class LoRAManager:
"enable all support modules types. "
)
self.target_modules.update(config.target_modules)
self.target_modules = get_normalized_target_modules(self.target_modules)
if max_lora_rank is not None:
self.max_lora_rank = max_lora_rank
@@ -435,15 +443,6 @@ class LoRAManager:
default=0,
)
def init_lora_weight_names(self):
"""
Add new LoRA weight names if needed based on the current `self.configs`.
"""
self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
self.target_modules
)
def load_lora_weights(self, lora_ref: LoRARef):
"""
Load the weights of a LoRA adapter to CPU memory and conduct post-loading validation.
@@ -467,7 +466,7 @@ class LoRAManager:
tp_size=self.tp_size,
tp_rank=self.tp_rank,
max_lora_rank=self.max_lora_rank,
lora_weight_names=self.lora_weight_names,
target_modules=self.target_modules,
base_model=self.base_model,
)
@@ -494,7 +493,7 @@ class LoRAManager:
continue
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in self.lora_weight_names:
if module_name.split(".")[-1] in self.target_modules:
layer_id = get_layer_id(module_name)
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
ROW_PARALLELISM_LINEAR_LORA_NAMES,
LoRAType,
get_hidden_dim,
get_normalized_lora_weight_names,
get_normalized_target_modules,
get_stacked_multiply,
get_weight_name,
get_target_module_name,
)
logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
tp_size: int,
tp_rank: int,
max_lora_rank: int,
lora_weight_names: Set[str],
target_modules: Set[str],
base_model: torch.nn.Module,
):
self.base_hf_config: AutoConfig = base_hf_config
@@ -62,7 +62,7 @@ class LoRAMemoryPool:
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: int = max_lora_rank
self.lora_weight_names: Set[str] = lora_weight_names
self.target_modules: Set[str] = target_modules
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
@@ -95,8 +95,8 @@ class LoRAMemoryPool:
"""
if config.r > self.max_lora_rank:
return False
weights = get_normalized_lora_weight_names(config.target_modules)
return weights.issubset(self.lora_weight_names)
target_module_names = get_normalized_target_modules(config.target_modules)
return target_module_names.issubset(self.target_modules)
if isinstance(config, LoRAConfig):
return _can_support(config)
@@ -139,10 +139,10 @@ class LoRAMemoryPool:
def init_buffer(
buffer: Dict[str, List[torch.Tensor]],
lora_weight_names: Set[str],
target_modules: Set[str],
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
):
for module_name in lora_weight_names:
for module_name in target_modules:
lora_shape = get_lora_shape_fn(
module_name, base_model, self.max_lora_rank
)
@@ -157,13 +157,13 @@ class LoRAMemoryPool:
init_buffer(
self.A_buffer,
self.lora_weight_names,
self.target_modules,
self.get_lora_A_shape,
)
init_buffer(
self.B_buffer,
self.lora_weight_names,
self.target_modules,
self.get_lora_B_shape,
)
@@ -242,32 +242,34 @@ class LoRAMemoryPool:
for layer_id in range(self.num_layer):
layer_weights = lora_adapter.layers[layer_id].weights
temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
weight_name: None for weight_name in self.A_buffer
target_module: None for target_module in self.A_buffer
}
temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
weight_name: None for weight_name in self.B_buffer
target_module: None for target_module in self.B_buffer
}
for name, weights in layer_weights.items():
lora_weight_name = get_weight_name(name, self.lora_weight_names)
target_module = get_target_module_name(name, self.target_modules)
if "lora_A" in name:
temp_A_buffer[lora_weight_name] = weights
temp_A_buffer[target_module] = weights
else:
temp_B_buffer[lora_weight_name] = weights
temp_B_buffer[target_module] = weights
if self.tp_size > 1:
cur_layer_modules = lora_modules[layer_id]
for module_name, module in cur_layer_modules.items():
weight_name = get_weight_name(module_name, self.lora_weight_names)
target_module = get_target_module_name(
module_name, self.target_modules
)
if temp_A_buffer[weight_name] is None:
if temp_A_buffer[target_module] is None:
# Skip weight slicing if the weight is not present in the adapter
continue
temp_A_buffer[weight_name] = module.slice_lora_a_weights(
temp_A_buffer[weight_name], self.tp_rank
temp_A_buffer[target_module] = module.slice_lora_a_weights(
temp_A_buffer[target_module], self.tp_rank
)
temp_B_buffer[weight_name] = module.slice_lora_b_weights(
temp_B_buffer[weight_name], self.tp_rank
temp_B_buffer[target_module] = module.slice_lora_b_weights(
temp_B_buffer[target_module], self.tp_rank
)
for name, weights in temp_A_buffer.items():
@@ -282,12 +284,12 @@ class LoRAMemoryPool:
load_lora_weight_tensor(buffer_view, weights)
def get_tensor(
self, weight_name: str, layer_id: int, lora_type: LoRAType
self, target_module: str, layer_id: int, lora_type: LoRAType
) -> torch.Tensor:
if lora_type == LoRAType.LORA_A:
return self.A_buffer[weight_name][layer_id]
return self.A_buffer[target_module][layer_id]
return self.B_buffer[weight_name][layer_id]
return self.B_buffer[target_module][layer_id]
def get_buffer_id(self, lora_uid: str):
return self.uid_to_buffer_id[lora_uid]
@@ -84,7 +84,7 @@ def get_hidden_dim(
raise NotImplementedError()
def get_normalized_lora_weight_names(
def get_normalized_target_modules(
target_modules: Iterable[str],
) -> set[str]:
"""
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(
result = set()
for name in target_modules:
weight_name = params_mapping.get(name, name)
result.add(weight_name)
normalized_name = params_mapping.get(name, name)
result.add(normalized_name)
return result
@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
return stacked_rank[module_name] if module_name in stacked_rank else 1
def get_weight_name(
target_name: str, lora_weight_names: Tuple[Set[str]]
) -> Optional[str]:
def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
"""
Get the weight name in lora_weight_names that can match target_name.
Get the target module name in target_modules that can match full_module_name.
If there is a weight name in lora_weight_names that can match target_name, return this name
If there is a target module name in target_modules that can match full_module_name, return this name
Else raise ValueError.
"""
for weight_name in lora_weight_names:
if weight_name in target_name:
return weight_name
for target_module in target_modules:
if target_module in full_module_name:
return target_module
raise ValueError(
f"Cannot find weight name for {target_name} in {lora_weight_names}"
f"Cannot find target module name for {full_module_name} in {target_modules}"
)
@@ -2874,6 +2874,8 @@ SUPPORTED_LORA_TARGET_MODULES = [
"gate_proj",
"up_proj",
"down_proj",
"qkv_proj",
"gate_up_proj",
]
LORA_TARGET_ALL_MODULES = "all"
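
Outside the diff, for readers unfamiliar with the memory-pool layout: the sketch below is a toy illustration of the `get_tensor(target_module=..., layer_id=..., lora_type=...)` lookup pattern this commit standardizes on. `TinyLoRAPool` and the buffer shapes are invented for illustration; this is not the real `LoRAMemoryPool`.

```python
# Toy stand-in (not the real LoRAMemoryPool): buffers keyed by normalized
# target module names, one tensor per layer, fetched via keyword arguments.
from enum import Enum, auto
from typing import Dict, List, Set

import torch


class LoRAType(Enum):
    LORA_A = auto()
    LORA_B = auto()


class TinyLoRAPool:
    """One (A, B) buffer list per target module, indexed by layer id."""

    def __init__(self, target_modules: Set[str], num_layers: int, rank: int, hidden: int):
        self.target_modules = target_modules
        self.A_buffer: Dict[str, List[torch.Tensor]] = {
            m: [torch.zeros(rank, hidden) for _ in range(num_layers)]
            for m in target_modules
        }
        self.B_buffer: Dict[str, List[torch.Tensor]] = {
            m: [torch.zeros(hidden, rank) for _ in range(num_layers)]
            for m in target_modules
        }

    def get_tensor(
        self, target_module: str, layer_id: int, lora_type: LoRAType
    ) -> torch.Tensor:
        buffer = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer
        return buffer[target_module][layer_id]


pool = TinyLoRAPool({"qkv_proj", "o_proj"}, num_layers=2, rank=8, hidden=64)
lora_a = pool.get_tensor(target_module="qkv_proj", layer_id=0, lora_type=LoRAType.LORA_A)
lora_b = pool.get_tensor(target_module="qkv_proj", layer_id=0, lora_type=LoRAType.LORA_B)
print(lora_a.shape, lora_b.shape)  # torch.Size([8, 64]) torch.Size([64, 8])
```

Keying both buffers by the same normalized target-module set is what lets the manager and the pool drop the separate `lora_weight_names` bookkeeping that this commit removes.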