"...hubert/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "576b02b19ec7b8273cc3c343a8d36272b63330ca"
Unverified Commit 477a101c, authored by Lifu Huang, committed by GitHub

Refactor LoRA handling to support adapter tensors in fused format (#6585)

parent 1a8f5f68
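
For orientation, a minimal sketch of the two adapter layouts this commit reconciles follows; the layer index and key names are illustrative rather than taken from any specific checkpoint.

# Unfused format: one LoRA tensor pair per projection module.
unfused_keys = [
    "layers.0.self_attn.q_proj.lora_A.weight",
    "layers.0.self_attn.k_proj.lora_A.weight",
    "layers.0.self_attn.v_proj.lora_A.weight",
    "layers.0.mlp.gate_proj.lora_A.weight",
    "layers.0.mlp.up_proj.lora_A.weight",
]

# Fused format: q/k/v and gate/up are already stacked into single tensors.
fused_keys = [
    "layers.0.self_attn.qkv_proj.lora_A.weight",
    "layers.0.mlp.gate_up_proj.lora_A.weight",
]

# After this change both layouts are normalized to the same internal convention:
# lora_A keyed by "qkv_proj" / "gate_up_proj", lora_B keyed by "q_proj" + "kv_proj"
# / "gate_up_proj", so the rest of the LoRA stack sees a single format.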
@@ -92,11 +92,12 @@ class LoRAAdapter(nn.Module):
         for i in range(self.base_hf_config.num_hidden_layers):
             layer = self.layers[i]
             weight_names = [name for name, _ in layer.weights.items()]
-            self.stack_qkv_proj(weight_names, layer.weights)
-            self.stack_gate_up_proj(weight_names, layer.weights)
+            self.normalize_qkv_proj(weight_names, layer.weights)
+            self.normalize_gate_up_proj(weight_names, layer.weights)

-    def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
+    def normalize_qkv_proj(
+        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
+    ):
         # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
         target_module = set()
         for weight_name in weight_names:
@@ -106,6 +107,8 @@ class LoRAAdapter(nn.Module):
                 target_module.add("q_proj")
             if "v_proj" in weight_name:
                 target_module.add("v_proj")
+            if "qkv_proj" in weight_name:
+                target_module.add("qkv_proj")
         if len(target_module) == 0:
             return
@@ -148,8 +151,30 @@ class LoRAAdapter(nn.Module):
                 if "k_proj" in target_module:
                     weights.pop(k_name)
                     weights.pop(v_name)
+            elif "qkv_proj" in weight_name:
+                # If qkv_proj is already stacked, we normalize it following the SGL convention.
+                qkv_name = weight_name
+                q_name = weight_name.replace("qkv_proj", "q_proj")
+                k_name = weight_name.replace("qkv_proj", "k_proj")
+                v_name = weight_name.replace("qkv_proj", "v_proj")
+                kv_name = weight_name.replace("qkv_proj", "kv_proj")
+                if "lora_A" in weight_name:
+                    weights[qkv_name] = weights[qkv_name].repeat(3, 1)
+                else:
+                    head_size = (
+                        self.base_hf_config.hidden_size
+                        // self.base_hf_config.num_attention_heads
+                    )
+                    weights[q_name], weights[kv_name] = torch.split(
+                        weights[qkv_name],
+                        [
+                            head_size * self.base_hf_config.num_attention_heads,
+                            head_size * self.base_hf_config.num_key_value_heads * 2,
+                        ],
+                        dim=0,
+                    )

-    def stack_gate_up_proj(
+    def normalize_gate_up_proj(
         self, weight_names: List[str], weights: Dict[str, torch.Tensor]
     ):
         for weight_name in weight_names:
@@ -179,3 +204,9 @@ class LoRAAdapter(nn.Module):
                 weights.pop(weight_name)
                 if up_name in weights:
                     weights.pop(up_name)
+            elif "gate_up_proj" in weight_name:
+                # If gate_up_proj is already stacked, we normalize it following the SGL convention.
+                gate_up_name = weight_name
+                if "lora_A" in weight_name:
+                    weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
+                # else: "lora_B" is already stacked, no operation is needed.
@@ -32,7 +32,7 @@ from sglang.srt.lora.utils import (
     LoRAType,
     get_customized_names_from_hf_names,
     get_layer_id,
-    get_stacked_name,
+    get_normalized_lora_weight_names,
     get_weight_name,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -101,10 +101,13 @@ class LoRAManager:
             self.hf_target_names.update(self.configs[name].target_modules)

         # Target lora weight names for lora_a and lora_b modules respectively.
-        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
-        self.lora_weight_names: Set[Tuple[str]] = set(
-            [get_stacked_name(module) for module in self.hf_target_names]
-        )
+        weights_A: List[str] = []
+        weights_B: List[str] = []
+        for module in self.hf_target_names:
+            lora_A, lora_B = get_normalized_lora_weight_names(module)
+            weights_A += lora_A
+            weights_B += lora_B
+        self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)

         # load all weights to cpu
         self.loras: Dict[str, LoRAAdapter] = {}
@@ -263,7 +266,18 @@ class LoRAManager:
         self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
             i: [] for i in range(self.base_hf_config.num_hidden_layers)
         }
         for module_name, module in self.base_model.named_modules():
+            # TODO (lifuhuang): in the future, we should consider generalizing the
+            # should_apply_lora function to support mapping by full module name instead
+            # of just the last part (e.g., "qkv_proj") to support scenarios with multiple
+            # attention stacks (e.g., multimodal models).
+            # See: https://github.com/sgl-project/sglang/issues/6608
+            if getattr(
+                self.base_model, "should_apply_lora", None
+            ) and not self.base_model.should_apply_lora(module_name):
+                continue
+
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
...
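
To see what the new lora_weight_names looks like in practice, here is a small sketch (the adapter's target modules are hypothetical) using the helper imported above:

from sglang.srt.lora.utils import get_normalized_lora_weight_names

hf_target_names = {"q_proj", "v_proj", "o_proj"}  # hypothetical adapter targets

weights_A, weights_B = [], []
for module in hf_target_names:
    # e.g. "q_proj" -> (["qkv_proj"], ["q_proj"]); unknown names map to themselves.
    lora_A, lora_B = get_normalized_lora_weight_names(module)
    weights_A += lora_A
    weights_B += lora_B

lora_weight_names = set(weights_A), set(weights_B)
# ({"qkv_proj", "o_proj"}, {"q_proj", "kv_proj", "o_proj"})
# An adapter that already targets the fused "qkv_proj" produces the same
# normalized names, which is what lets fused and unfused adapters share buffers.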
@@ -91,18 +91,16 @@ class LoRAMemoryPool:
     def init_buffers(
         self,
-        lora_weight_names: Set[Tuple[str]],
+        lora_weight_names: Tuple[Set[str]],
         base_model: torch.nn.Module,
     ):

         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
-        self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
+        self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
         device = next(base_model.parameters()).device
-        lora_module_A_names = set([name[0] for name in lora_weight_names])
-        lora_module_B_names = set([name[1] for name in lora_weight_names])

         # Init A tensor, column_major=False
-        for module_A in lora_module_A_names:
+        for module_A in lora_weight_names[0]:
             lora_A_shape = self.get_lora_A_shape(module_A, base_model)
             self.A_buffer[module_A] = [
                 torch.empty(
@@ -110,10 +108,10 @@ class LoRAMemoryPool:
                     dtype=self.dtype,
                     device=device,
                 )
-                for i in range(self.num_layer)
+                for _ in range(self.num_layer)
             ]

         # Init B tensor, column_major=True
-        for module_B in lora_module_B_names:
+        for module_B in lora_weight_names[1]:
             lora_B_shape = self.get_lora_B_shape(module_B, base_model)
             self.B_buffer[module_B] = [
                 torch.empty(
...
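
A simplified sketch of how the pool is keyed after this change (not the real pool code; the buffer shapes here are placeholders for whatever get_lora_A_shape / get_lora_B_shape would return): A buffers come from the first set of names, B buffers from the second, with one tensor per layer.

import torch

lora_weight_names = ({"qkv_proj", "o_proj"}, {"q_proj", "kv_proj", "o_proj"})
num_layer, max_loras_per_batch, max_rank, hidden = 2, 4, 8, 64  # illustrative sizes

A_buffer, B_buffer = {}, {}
for module_A in lora_weight_names[0]:  # lora_A weight names
    A_buffer[module_A] = [
        torch.empty((max_loras_per_batch, max_rank, hidden)) for _ in range(num_layer)
    ]
for module_B in lora_weight_names[1]:  # lora_B weight names
    B_buffer[module_B] = [
        torch.empty((max_loras_per_batch, hidden, max_rank)) for _ in range(num_layer)
    ]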
 import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple

 import torch
@@ -106,18 +106,22 @@ def get_hidden_dim(
         raise NotImplementedError()

-def get_stacked_name(name: str) -> Tuple[str]:
+def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
     """
-    Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
+    Mapping a target module name to names of the normalized LoRA weights.
+    Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
-        "q_proj": ("qkv_proj", "q_proj"),
-        "k_proj": ("qkv_proj", "kv_proj"),
-        "v_proj": ("qkv_proj", "kv_proj"),
-        "gate_proj": ("gate_up_proj", "gate_up_proj"),
-        "up_proj": ("gate_up_proj", "gate_up_proj"),
+        "q_proj": (["qkv_proj"], ["q_proj"]),
+        "k_proj": (["qkv_proj"], ["kv_proj"]),
+        "v_proj": (["qkv_proj"], ["kv_proj"]),
+        "gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
+        "up_proj": (["gate_up_proj"], ["gate_up_proj"]),
+        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
+        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
     }
-    return params_mapping.get(name, (name, name))
+    stacked = params_mapping.get(name, ([name], [name]))
+    return stacked

 def get_stacked_multiply(module_name: str) -> int:
@@ -133,7 +137,7 @@ def get_stacked_multiply(module_name: str) -> int:
 def get_weight_name(
-    target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
+    target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
 ) -> Optional[str]:
     """
     target_name is name of a given module,
@@ -142,9 +146,9 @@ def get_weight_name(
     Else raise ValueError.
     """
     idx = 0 if lora_type == LoRAType.LORA_A else 1
-    for weight_name_pair in lora_weight_names:
-        if weight_name_pair[idx] in target_name:
-            return weight_name_pair[idx]
+    for weight_name in lora_weight_names[idx]:
+        if weight_name in target_name:
+            return weight_name
     raise ValueError(
         f"Cannot find weight name for {target_name} in {lora_weight_names}"
     )
...
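
A short usage sketch of the updated helpers (the module names are illustrative; assumes the sglang package is importable):

from sglang.srt.lora.utils import (
    LoRAType,
    get_normalized_lora_weight_names,
    get_weight_name,
)

# Fused and unfused target modules normalize to the same weight names.
get_normalized_lora_weight_names("qkv_proj")  # (["qkv_proj"], ["q_proj", "kv_proj"])
get_normalized_lora_weight_names("up_proj")   # (["gate_up_proj"], ["gate_up_proj"])

# get_weight_name now indexes into the (lora_A names, lora_B names) tuple directly.
lora_weight_names = ({"qkv_proj", "o_proj"}, {"q_proj", "kv_proj", "o_proj"})
get_weight_name(
    "model.layers.0.self_attn.qkv_proj", lora_weight_names, LoRAType.LORA_A
)  # "qkv_proj"
get_weight_name(
    "model.layers.0.self_attn.o_proj", lora_weight_names, LoRAType.LORA_B
)  # "o_proj"; an unmatched module name raises ValueError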
@@ -17,6 +17,7 @@
 import logging
 import math
+import re
 from collections.abc import Iterable
 from typing import List, Optional, Tuple
@@ -392,6 +393,10 @@ class Phi4MMForCausalLM(nn.Module):
         "gate_up_proj": ["gate_proj", "up_proj"],
     }

+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -446,6 +451,9 @@ class Phi4MMForCausalLM(nn.Module):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

+    def should_apply_lora(self, module_name: str) -> Optional[str]:
+        return self.lora_pattern.match(module_name)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
...
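
A quick check of what the new lora_pattern accepts (the module names below are illustrative): should_apply_lora only returns a match for the language-model decoder projections, so LoRA is skipped for modules outside that subtree, such as the multimodal encoders.

import re

lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

print(bool(lora_pattern.match("language_model.model.layers.3.self_attn.qkv_proj")))  # True
print(bool(lora_pattern.match("language_model.model.layers.3.mlp.gate_up_proj")))    # True
print(bool(lora_pattern.match("model.layers.3.self_attn.qkv_proj")))                 # False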
@@ -1473,7 +1473,7 @@ class ServerArgs:
             self.max_loras_per_batch > 0
             # FIXME
             and (self.lora_paths is None or self.disable_radix_cache)
-        ), "compatibility of lora and cuda graph and radix attention is in progress"
+        ), "compatibility of lora and radix attention is in progress"
         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
         assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
...