Unverified Commit 477a101c authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

Refactor LoRA handling to support adapter tensors in fused format (#6585)

parent 1a8f5f68
......@@ -92,11 +92,12 @@ class LoRAAdapter(nn.Module):
for i in range(self.base_hf_config.num_hidden_layers):
layer = self.layers[i]
weight_names = [name for name, _ in layer.weights.items()]
self.stack_qkv_proj(weight_names, layer.weights)
self.stack_gate_up_proj(weight_names, layer.weights)
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
self.normalize_qkv_proj(weight_names, layer.weights)
self.normalize_gate_up_proj(weight_names, layer.weights)
def normalize_qkv_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
):
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
target_module = set()
for weight_name in weight_names:
......@@ -106,6 +107,8 @@ class LoRAAdapter(nn.Module):
target_module.add("q_proj")
if "v_proj" in weight_name:
target_module.add("v_proj")
if "qkv_proj" in weight_name:
target_module.add("qkv_proj")
if len(target_module) == 0:
return
......@@ -148,8 +151,30 @@ class LoRAAdapter(nn.Module):
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
elif "qkv_proj" in weight_name:
# If qkv_proj is already stacked, we normalize it following the SGL convention.
qkv_name = weight_name
q_name = weight_name.replace("qkv_proj", "q_proj")
k_name = weight_name.replace("qkv_proj", "k_proj")
v_name = weight_name.replace("qkv_proj", "v_proj")
kv_name = weight_name.replace("qkv_proj", "kv_proj")
if "lora_A" in weight_name:
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
else:
head_size = (
self.base_hf_config.hidden_size
// self.base_hf_config.num_attention_heads
)
weights[q_name], weights[kv_name] = torch.split(
weights[qkv_name],
[
head_size * self.base_hf_config.num_attention_heads,
head_size * self.base_hf_config.num_key_value_heads * 2,
],
dim=0,
)
def stack_gate_up_proj(
def normalize_gate_up_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
):
for weight_name in weight_names:
......@@ -179,3 +204,9 @@ class LoRAAdapter(nn.Module):
weights.pop(weight_name)
if up_name in weights:
weights.pop(up_name)
elif "gate_up_proj" in weight_name:
# If gate_up_proj is already stacked, we normalize it following the SGL convention
gate_up_name = weight_name
if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
# else: "lora_B" is already stacked, no operations is needed.
......@@ -32,7 +32,7 @@ from sglang.srt.lora.utils import (
LoRAType,
get_customized_names_from_hf_names,
get_layer_id,
get_stacked_name,
get_normalized_lora_weight_names,
get_weight_name,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -101,10 +101,13 @@ class LoRAManager:
self.hf_target_names.update(self.configs[name].target_modules)
# Target lora weight names for lora_a and lora_b modules respectively.
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
self.lora_weight_names: Set[Tuple[str]] = set(
[get_stacked_name(module) for module in self.hf_target_names]
)
weights_A: List[str] = []
weights_B: List[str] = []
for module in self.hf_target_names:
lora_A, lora_B = get_normalized_lora_weight_names(module)
weights_A += lora_A
weights_B += lora_B
self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
# load all weights to cpu
self.loras: Dict[str, LoRAAdapter] = {}
......@@ -263,7 +266,18 @@ class LoRAManager:
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
i: [] for i in range(self.base_hf_config.num_hidden_layers)
}
for module_name, module in self.base_model.named_modules():
# TODO (lifuhuang): in the future, we should consider generalizing the
# should_apply_lora function to support mapping by full module name instead
# of just the last part (e.g., "qkv_proj") to support scenarios with multiple
# attention stacks (e.g., multimodal models).
# See: https://github.com/sgl-project/sglang/issues/6608
if getattr(
self.base_model, "should_apply_lora", None
) and not self.base_model.should_apply_lora(module_name):
continue
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names:
layer_id = get_layer_id(module_name)
......
......@@ -91,18 +91,16 @@ class LoRAMemoryPool:
def init_buffers(
self,
lora_weight_names: Set[Tuple[str]],
lora_weight_names: Tuple[Set[str]],
base_model: torch.nn.Module,
):
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
device = next(base_model.parameters()).device
lora_module_A_names = set([name[0] for name in lora_weight_names])
lora_module_B_names = set([name[1] for name in lora_weight_names])
# Init A tensor, column_major=False
for module_A in lora_module_A_names:
for module_A in lora_weight_names[0]:
lora_A_shape = self.get_lora_A_shape(module_A, base_model)
self.A_buffer[module_A] = [
torch.empty(
......@@ -110,10 +108,10 @@ class LoRAMemoryPool:
dtype=self.dtype,
device=device,
)
for i in range(self.num_layer)
for _ in range(self.num_layer)
]
# Init B tensor, column_major=True
for module_B in lora_module_B_names:
for module_B in lora_weight_names[1]:
lora_B_shape = self.get_lora_B_shape(module_B, base_model)
self.B_buffer[module_B] = [
torch.empty(
......
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Set, Tuple
from typing import List, Optional, Set, Tuple
import torch
......@@ -106,18 +106,22 @@ def get_hidden_dim(
raise NotImplementedError()
def get_stacked_name(name: str) -> Tuple[str]:
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
"""
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
Mapping a target module name to names of the normized LoRA weights.
Returned tuple contains (name for Lora A, name for Lora B)
"""
params_mapping = {
"q_proj": ("qkv_proj", "q_proj"),
"k_proj": ("qkv_proj", "kv_proj"),
"v_proj": ("qkv_proj", "kv_proj"),
"gate_proj": ("gate_up_proj", "gate_up_proj"),
"up_proj": ("gate_up_proj", "gate_up_proj"),
"q_proj": (["qkv_proj"], ["q_proj"]),
"k_proj": (["qkv_proj"], ["kv_proj"]),
"v_proj": (["qkv_proj"], ["kv_proj"]),
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
"up_proj": (["gate_up_proj"], ["gate_up_proj"]),
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
}
return params_mapping.get(name, (name, name))
stacked = params_mapping.get(name, ([name], [name]))
return stacked
def get_stacked_multiply(module_name: str) -> int:
......@@ -133,7 +137,7 @@ def get_stacked_multiply(module_name: str) -> int:
def get_weight_name(
target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
) -> Optional[str]:
"""
target_name is name of a given module,
......@@ -142,9 +146,9 @@ def get_weight_name(
Else raise ValueError.
"""
idx = 0 if lora_type == LoRAType.LORA_A else 1
for weight_name_pair in lora_weight_names:
if weight_name_pair[idx] in target_name:
return weight_name_pair[idx]
for weight_name in lora_weight_names[idx]:
if weight_name in target_name:
return weight_name
raise ValueError(
f"Cannot find weight name for {target_name} in {lora_weight_names}"
)
......
......@@ -17,6 +17,7 @@
import logging
import math
import re
from collections.abc import Iterable
from typing import List, Optional, Tuple
......@@ -392,6 +393,10 @@ class Phi4MMForCausalLM(nn.Module):
"gate_up_proj": ["gate_proj", "up_proj"],
}
lora_pattern = re.compile(
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)
def __init__(
self,
config: PretrainedConfig,
......@@ -446,6 +451,9 @@ class Phi4MMForCausalLM(nn.Module):
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
return pattern.pad_input_tokens(input_ids, mm_inputs)
def should_apply_lora(self, module_name: str) -> Optional[str]:
return self.lora_pattern.match(module_name)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......
......@@ -1473,7 +1473,7 @@ class ServerArgs:
self.max_loras_per_batch > 0
# FIXME
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
), "compatibility of lora and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment