sglang · Commit 477a101c (Unverified)

Refactor LoRA handling to support adapter tensors in fused format (#6585)

Authored May 26, 2025 by Lifu Huang; committed via GitHub on May 26, 2025.
Parent: 1a8f5f68
Showing 6 changed files with 86 additions and 31 deletions:

  python/sglang/srt/lora/lora.py          +36  -5
  python/sglang/srt/lora/lora_manager.py  +19  -5
  python/sglang/srt/lora/mem_pool.py       +5  -7
  python/sglang/srt/lora/utils.py         +17 -13
  python/sglang/srt/models/phi4mm.py       +8  -0
  python/sglang/srt/server_args.py         +1  -1
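For context on the title: "fused format" here means an adapter checkpoint whose LoRA tensors are keyed by the already-merged projections (qkv_proj, gate_up_proj) rather than by the individual q_proj/k_proj/v_proj or gate_proj/up_proj modules. A minimal sketch of the two key layouts (the module paths below are illustrative, not taken from a real checkpoint):

```python
# Unfused (per-projection) adapter keys, which were already supported:
unfused_keys = [
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.layers.0.self_attn.k_proj.lora_A.weight",
    "base_model.model.layers.0.self_attn.v_proj.lora_A.weight",
    "base_model.model.layers.0.mlp.gate_proj.lora_B.weight",
]

# Fused adapter keys, which this change normalizes into the SGL convention:
fused_keys = [
    "base_model.model.layers.0.self_attn.qkv_proj.lora_A.weight",
    "base_model.model.layers.0.mlp.gate_up_proj.lora_B.weight",
]
```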
python/sglang/srt/lora/lora.py
@@ -92,11 +92,12 @@ class LoRAAdapter(nn.Module):
         for i in range(self.base_hf_config.num_hidden_layers):
             layer = self.layers[i]
             weight_names = [name for name, _ in layer.weights.items()]
-            self.stack_qkv_proj(weight_names, layer.weights)
-            self.stack_gate_up_proj(weight_names, layer.weights)
+            self.normalize_qkv_proj(weight_names, layer.weights)
+            self.normalize_gate_up_proj(weight_names, layer.weights)
 
-    def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
+    def normalize_qkv_proj(
+        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
+    ):
         # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
         target_module = set()
         for weight_name in weight_names:
@@ -106,6 +107,8 @@ class LoRAAdapter(nn.Module):
                 target_module.add("q_proj")
             if "v_proj" in weight_name:
                 target_module.add("v_proj")
+            if "qkv_proj" in weight_name:
+                target_module.add("qkv_proj")
         if len(target_module) == 0:
             return
@@ -148,8 +151,30 @@ class LoRAAdapter(nn.Module):
                 if "k_proj" in target_module:
                     weights.pop(k_name)
                     weights.pop(v_name)
+            elif "qkv_proj" in weight_name:
+                # If qkv_proj is already stacked, we normalize it following the SGL convention.
+                qkv_name = weight_name
+                q_name = weight_name.replace("qkv_proj", "q_proj")
+                k_name = weight_name.replace("qkv_proj", "k_proj")
+                v_name = weight_name.replace("qkv_proj", "v_proj")
+                kv_name = weight_name.replace("qkv_proj", "kv_proj")
+                if "lora_A" in weight_name:
+                    weights[qkv_name] = weights[qkv_name].repeat(3, 1)
+                else:
+                    head_size = (
+                        self.base_hf_config.hidden_size
+                        // self.base_hf_config.num_attention_heads
+                    )
+                    weights[q_name], weights[kv_name] = torch.split(
+                        weights[qkv_name],
+                        [
+                            head_size * self.base_hf_config.num_attention_heads,
+                            head_size * self.base_hf_config.num_key_value_heads * 2,
+                        ],
+                        dim=0,
+                    )
 
-    def stack_gate_up_proj(
+    def normalize_gate_up_proj(
         self, weight_names: List[str], weights: Dict[str, torch.Tensor]
     ):
         for weight_name in weight_names:
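As a sanity check on the repeat/split logic in the hunk above, here is a small self-contained sketch with toy dimensions (hidden_size 8, 2 attention heads, 1 KV head, rank 2); it mirrors the normalization but is not the adapter code itself:

```python
import torch

# Toy dimensions, for illustration only.
hidden_size = 8
num_attention_heads = 2
num_key_value_heads = 1
head_size = hidden_size // num_attention_heads  # 4
rank = 2

# Fused qkv LoRA A: a single (rank, hidden) matrix shared by q/k/v.
# The normalization repeats it three times along dim 0, as in the diff.
qkv_lora_A = torch.randn(rank, hidden_size)
stacked_A = qkv_lora_A.repeat(3, 1)
assert stacked_A.shape == (3 * rank, hidden_size)

# Fused qkv LoRA B: rows cover the q, k and v outputs; columns are the rank.
q_out = head_size * num_attention_heads       # q rows
kv_out = head_size * num_key_value_heads * 2  # k and v rows together
qkv_lora_B = torch.randn(q_out + kv_out, rank)

# Split into the q_proj part and the combined kv_proj part, as in the diff.
q_B, kv_B = torch.split(qkv_lora_B, [q_out, kv_out], dim=0)
assert q_B.shape == (q_out, rank)
assert kv_B.shape == (kv_out, rank)
```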
@@ -179,3 +204,9 @@ class LoRAAdapter(nn.Module):
                 weights.pop(weight_name)
                 if up_name in weights:
                     weights.pop(up_name)
+            elif "gate_up_proj" in weight_name:
+                # If gate_up_proj is already stacked, we normalize it following the SGL convention
+                gate_up_name = weight_name
+                if "lora_A" in weight_name:
+                    weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
+                # else: "lora_B" is already stacked, no operations is needed.
python/sglang/srt/lora/lora_manager.py
@@ -32,7 +32,7 @@ from sglang.srt.lora.utils import (
     LoRAType,
     get_customized_names_from_hf_names,
     get_layer_id,
-    get_stacked_name,
+    get_normalized_lora_weight_names,
     get_weight_name,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -101,10 +101,13 @@ class LoRAManager:
             self.hf_target_names.update(self.configs[name].target_modules)
 
         # Target lora weight names for lora_a and lora_b modules respectively.
-        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
-        self.lora_weight_names: Set[Tuple[str]] = set(
-            [get_stacked_name(module) for module in self.hf_target_names]
-        )
+        weights_A: List[str] = []
+        weights_B: List[str] = []
+        for module in self.hf_target_names:
+            lora_A, lora_B = get_normalized_lora_weight_names(module)
+            weights_A += lora_A
+            weights_B += lora_B
+        self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
 
         # load all weights to cpu
         self.loras: Dict[str, LoRAAdapter] = {}
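For illustration, with a typical set of HF target modules the new tuple-of-sets would come out roughly as below (the values restate the mapping in lora/utils.py; o_proj falls through to the identity default):

```python
# Hypothetical adapter target modules:
hf_target_names = {"q_proj", "v_proj", "o_proj", "gate_proj", "up_proj"}

# Resulting lora_weight_names after the loop above (sketch, not manager code):
lora_weight_names = (
    {"qkv_proj", "o_proj", "gate_up_proj"},           # names for LoRA A buffers
    {"q_proj", "kv_proj", "o_proj", "gate_up_proj"},  # names for LoRA B buffers
)
```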
@@ -263,7 +266,18 @@ class LoRAManager:
         self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
             i: [] for i in range(self.base_hf_config.num_hidden_layers)
         }
+
         for module_name, module in self.base_model.named_modules():
+            # TODO (lifuhuang): in the future, we should consider generalizing the
+            # should_apply_lora function to support mapping by full module name instead
+            # of just the last part (e.g., "qkv_proj") to support scenarios with multiple
+            # attention stacks (e.g., multimodal models).
+            # See: https://github.com/sgl-project/sglang/issues/6608
+            if getattr(
+                self.base_model, "should_apply_lora", None
+            ) and not self.base_model.should_apply_lora(module_name):
+                continue
+
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
python/sglang/srt/lora/mem_pool.py
@@ -91,18 +91,16 @@ class LoRAMemoryPool:
 
     def init_buffers(
         self,
-        lora_weight_names: Set[Tuple[str]],
+        lora_weight_names: Tuple[Set[str]],
         base_model: torch.nn.Module,
     ):
         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
-        self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
+        self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
         device = next(base_model.parameters()).device
-        lora_module_A_names = set([name[0] for name in lora_weight_names])
-        lora_module_B_names = set([name[1] for name in lora_weight_names])
 
         # Init A tensor, column_major=False
-        for module_A in lora_module_A_names:
+        for module_A in lora_weight_names[0]:
             lora_A_shape = self.get_lora_A_shape(module_A, base_model)
             self.A_buffer[module_A] = [
                 torch.empty(
@@ -110,10 +108,10 @@ class LoRAMemoryPool:
                     dtype=self.dtype,
                     device=device,
                 )
-                for i in range(self.num_layer)
+                for _ in range(self.num_layer)
             ]
         # Init B tensor, column_major=True
-        for module_B in lora_module_B_names:
+        for module_B in lora_weight_names[1]:
             lora_B_shape = self.get_lora_B_shape(module_B, base_model)
             self.B_buffer[module_B] = [
                 torch.empty(
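To make the resulting layout concrete: after init_buffers, A_buffer and B_buffer are keyed by the normalized weight names and hold one tensor per layer. A minimal sketch with placeholder shapes (the real shapes come from get_lora_A_shape / get_lora_B_shape, which are not shown in this diff):

```python
import torch

# Assumed inputs, for illustration only.
lora_weight_names = ({"qkv_proj", "o_proj"}, {"q_proj", "kv_proj", "o_proj"})
num_layer = 2
placeholder_shape = (4, 8)  # stand-in; actual shapes depend on the base model

A_buffer = {
    module_A: [torch.empty(placeholder_shape) for _ in range(num_layer)]
    for module_A in lora_weight_names[0]
}
B_buffer = {
    module_B: [torch.empty(placeholder_shape) for _ in range(num_layer)]
    for module_B in lora_weight_names[1]
}
assert len(A_buffer["qkv_proj"]) == num_layer
```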
python/sglang/srt/lora/utils.py
 import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple
 
 import torch
@@ -106,18 +106,22 @@ def get_hidden_dim(
         raise NotImplementedError()
 
 
-def get_stacked_name(name: str) -> Tuple[str]:
+def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
     """
-    Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
+    Mapping a target module name to names of the normized LoRA weights.
+    Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
-        "q_proj": ("qkv_proj", "q_proj"),
-        "k_proj": ("qkv_proj", "kv_proj"),
-        "v_proj": ("qkv_proj", "kv_proj"),
-        "gate_proj": ("gate_up_proj", "gate_up_proj"),
-        "up_proj": ("gate_up_proj", "gate_up_proj"),
+        "q_proj": (["qkv_proj"], ["q_proj"]),
+        "k_proj": (["qkv_proj"], ["kv_proj"]),
+        "v_proj": (["qkv_proj"], ["kv_proj"]),
+        "gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
+        "up_proj": (["gate_up_proj"], ["gate_up_proj"]),
+        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
+        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
     }
-    return params_mapping.get(name, (name, name))
+    stacked = params_mapping.get(name, ([name], [name]))
+    return stacked
 
 
 def get_stacked_multiply(module_name: str) -> int:
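A quick usage sketch of the renamed helper, assuming sglang at this commit is importable; the expected values simply restate the mapping table above:

```python
from sglang.srt.lora.utils import get_normalized_lora_weight_names

assert get_normalized_lora_weight_names("q_proj") == (["qkv_proj"], ["q_proj"])
assert get_normalized_lora_weight_names("qkv_proj") == (["qkv_proj"], ["q_proj", "kv_proj"])
# Unmapped modules fall through to the identity default:
assert get_normalized_lora_weight_names("o_proj") == (["o_proj"], ["o_proj"])
```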
@@ -133,7 +137,7 @@ def get_stacked_multiply(module_name: str) -> int:
 
 
 def get_weight_name(
-    target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
+    target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
 ) -> Optional[str]:
     """
     target_name is name of a given module,
@@ -142,9 +146,9 @@ def get_weight_name(
     Else raise ValueError.
     """
     idx = 0 if lora_type == LoRAType.LORA_A else 1
-    for weight_name_pair in lora_weight_names:
-        if weight_name_pair[idx] in target_name:
-            return weight_name_pair[idx]
+    for weight_name in lora_weight_names[idx]:
+        if weight_name in target_name:
+            return weight_name
     raise ValueError(
         f"Cannot find weight name for {target_name} in {lora_weight_names}"
     )
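And a corresponding sketch for the updated get_weight_name, using the tuple-of-sets form introduced here (the module path and the example sets are illustrative):

```python
from sglang.srt.lora.utils import LoRAType, get_weight_name

lora_weight_names = ({"qkv_proj", "o_proj"}, {"q_proj", "kv_proj", "o_proj"})

# Substring match against the normalized names for the requested LoRA type.
name_a = get_weight_name(
    "model.layers.0.self_attn.qkv_proj", lora_weight_names, LoRAType.LORA_A
)
assert name_a == "qkv_proj"
```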
python/sglang/srt/models/phi4mm.py
@@ -17,6 +17,7 @@
 import logging
 import math
+import re
 from collections.abc import Iterable
 from typing import List, Optional, Tuple
@@ -392,6 +393,10 @@ class Phi4MMForCausalLM(nn.Module):
         "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -446,6 +451,9 @@ class Phi4MMForCausalLM(nn.Module):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
+    def should_apply_lora(self, module_name: str) -> Optional[str]:
+        return self.lora_pattern.match(module_name)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
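For reference, the new lora_pattern gate only admits the language-model attention/MLP projections; a quick check of how the regex behaves (module names other than the pattern itself are illustrative):

```python
import re

# Same pattern as added to Phi4MMForCausalLM above.
lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

assert lora_pattern.match("language_model.model.layers.3.self_attn.qkv_proj")
assert lora_pattern.match("language_model.model.layers.7.mlp.gate_up_proj")
# Modules outside the language model (illustrative name) are skipped:
assert not lora_pattern.match("vision_embed_tokens.img_projection.0")
```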
python/sglang/srt/server_args.py
@@ -1473,7 +1473,7 @@ class ServerArgs:
             self.max_loras_per_batch > 0
             # FIXME
             and (self.lora_paths is None or self.disable_radix_cache)
-        ), "compatibility of lora and cuda graph and radix attention is in progress"
+        ), "compatibility of lora and radix attention is in progress"
         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
         assert self.gpu_id_step >= 1, "gpu_id_step must be positive"