change / sglang · Commits · 4b74c3fc

Unverified commit 4b74c3fc, authored Aug 17, 2025 by Lifu Huang; committed by GitHub, Aug 17, 2025
Parent: ce3ca9b0

[chore] Clean up redundant lora_weight_names concept to simplify code (#9131)

Showing 4 changed files with 55 additions and 54 deletions (+55 -54)
python/sglang/srt/lora/lora_manager.py   +17 -18
python/sglang/srt/lora/mem_pool.py       +26 -24
python/sglang/srt/lora/utils.py          +10 -12
python/sglang/srt/utils.py               +2  -0
python/sglang/srt/lora/lora_manager.py

@@ -32,8 +32,8 @@ from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
     get_layer_id,
-    get_normalized_lora_weight_names,
-    get_weight_name,
+    get_normalized_target_modules,
+    get_target_module_name,
 )
 from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -350,12 +350,20 @@ class LoRAManager:
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-                weight_name = get_weight_name(
-                    module_name, self.memory_pool.lora_weight_names
+                target_module = get_target_module_name(
+                    module_name, self.memory_pool.target_modules
                 )
                 module.set_lora_info(
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_A,
+                    ),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_B,
+                    ),
                 )

     def init_state(

@@ -380,7 +388,6 @@ class LoRAManager:
             max_lora_rank=max_lora_rank,
             target_modules=target_modules,
         )
-        self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
         self.update_lora_info()

@@ -426,6 +433,7 @@ class LoRAManager:
                         "enable all support modules types. "
                     )
                 self.target_modules.update(config.target_modules)
+        self.target_modules = get_normalized_target_modules(self.target_modules)

         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank

@@ -435,15 +443,6 @@ class LoRAManager:
             default=0,
         )

-    def init_lora_weight_names(self):
-        """
-        Add new LoRA weight names if needed based on the current `self.configs`.
-        """
-
-        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
-            self.target_modules
-        )
-
     def load_lora_weights(self, lora_ref: LoRARef):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.

@@ -467,7 +466,7 @@ class LoRAManager:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
             max_lora_rank=self.max_lora_rank,
-            lora_weight_names=self.lora_weight_names,
+            target_modules=self.target_modules,
             base_model=self.base_model,
         )

@@ -494,7 +493,7 @@ class LoRAManager:
                 continue
             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in self.lora_weight_names:
+            if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
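To make the renamed flow concrete, here is a minimal, self-contained sketch (not the sglang implementation) of what update_lora_info now does: resolve a layer's full module name to its normalized target module, then use that name as the key when fetching the LoRA A/B tensors from the memory pool. The module path, shapes, and dict-based toy buffers are illustrative assumptions.

# Minimal sketch only; module names, shapes, and the dict-based "pool" are assumptions,
# standing in for LoRAManager.update_lora_info + LoRAMemoryPool.get_tensor.
from enum import Enum, auto
from typing import Dict, List, Set

import torch


class LoRAType(Enum):
    LORA_A = auto()
    LORA_B = auto()


def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
    # Same substring match as sglang.srt.lora.utils.get_target_module_name (diff further below).
    for target_module in target_modules:
        if target_module in full_module_name:
            return target_module
    raise ValueError(f"Cannot find target module name for {full_module_name}")


# Toy stand-in for the pool: A/B buffers keyed by normalized target module, one tensor per layer.
num_layers, rank, hidden = 2, 8, 16
target_modules = {"qkv_proj", "o_proj"}
A_buffer: Dict[str, List[torch.Tensor]] = {
    m: [torch.zeros(rank, hidden) for _ in range(num_layers)] for m in target_modules
}
B_buffer: Dict[str, List[torch.Tensor]] = {
    m: [torch.zeros(hidden, rank) for _ in range(num_layers)] for m in target_modules
}


def get_tensor(target_module: str, layer_id: int, lora_type: LoRAType) -> torch.Tensor:
    buffer = A_buffer if lora_type == LoRAType.LORA_A else B_buffer
    return buffer[target_module][layer_id]


# Resolve a hypothetical full module name, then fetch its per-layer LoRA A/B tensors.
module_key = get_target_module_name("model.layers.1.self_attn.qkv_proj", target_modules)
lora_A = get_tensor(target_module=module_key, layer_id=1, lora_type=LoRAType.LORA_A)
lora_B = get_tensor(target_module=module_key, layer_id=1, lora_type=LoRAType.LORA_B)
print(module_key, tuple(lora_A.shape), tuple(lora_B.shape))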
python/sglang/srt/lora/mem_pool.py

@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
-    get_normalized_lora_weight_names,
+    get_normalized_target_modules,
     get_stacked_multiply,
-    get_weight_name,
+    get_target_module_name,
 )

 logger = logging.getLogger(__name__)

@@ -52,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names: Set[str],
+        target_modules: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config

@@ -62,7 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-        self.lora_weight_names: Set[str] = lora_weight_names
+        self.target_modules: Set[str] = target_modules

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape

@@ -95,8 +95,8 @@ class LoRAMemoryPool:
             """
             if config.r > self.max_lora_rank:
                 return False
-            weights = get_normalized_lora_weight_names(config.target_modules)
-            return weights.issubset(self.lora_weight_names)
+            target_module_names = get_normalized_target_modules(config.target_modules)
+            return target_module_names.issubset(self.target_modules)

         if isinstance(config, LoRAConfig):
             return _can_support(config)

@@ -139,10 +139,10 @@ class LoRAMemoryPool:
         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
-            lora_weight_names: Set[str],
+            target_modules: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-            for module_name in lora_weight_names:
+            for module_name in target_modules:
                 lora_shape = get_lora_shape_fn(
                     module_name, base_model, self.max_lora_rank
                 )

@@ -157,13 +157,13 @@ class LoRAMemoryPool:
         init_buffer(
             self.A_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_A_shape,
         )

         init_buffer(
             self.B_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_B_shape,
         )

@@ -242,32 +242,34 @@ class LoRAMemoryPool:
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.A_buffer
+                target_module: None for target_module in self.A_buffer
             }
             temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.B_buffer
+                target_module: None for target_module in self.B_buffer
             }
             for name, weights in layer_weights.items():
-                lora_weight_name = get_weight_name(name, self.lora_weight_names)
+                target_module = get_target_module_name(name, self.target_modules)
                 if "lora_A" in name:
-                    temp_A_buffer[lora_weight_name] = weights
+                    temp_A_buffer[target_module] = weights
                 else:
-                    temp_B_buffer[lora_weight_name] = weights
+                    temp_B_buffer[target_module] = weights

             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
-                    weight_name = get_weight_name(module_name, self.lora_weight_names)
+                    target_module = get_target_module_name(
+                        module_name, self.target_modules
+                    )

-                    if temp_A_buffer[weight_name] is None:
+                    if temp_A_buffer[target_module] is None:
                         # Skip weight slicing if the weight is not present in the adapter
                         continue

-                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                        temp_A_buffer[weight_name], self.tp_rank
+                    temp_A_buffer[target_module] = module.slice_lora_a_weights(
+                        temp_A_buffer[target_module], self.tp_rank
                     )
-                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                        temp_B_buffer[weight_name], self.tp_rank
+                    temp_B_buffer[target_module] = module.slice_lora_b_weights(
+                        temp_B_buffer[target_module], self.tp_rank
                     )

             for name, weights in temp_A_buffer.items():

@@ -282,12 +284,12 @@ class LoRAMemoryPool:
                 load_lora_weight_tensor(buffer_view, weights)

     def get_tensor(
-        self, weight_name: str, layer_id: int, lora_type: LoRAType
+        self, target_module: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
         if lora_type == LoRAType.LORA_A:
-            return self.A_buffer[weight_name][layer_id]
+            return self.A_buffer[target_module][layer_id]

-        return self.B_buffer[weight_name][layer_id]
+        return self.B_buffer[target_module][layer_id]

     def get_buffer_id(self, lora_uid: str):
         return self.uid_to_buffer_id[lora_uid]
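The memory pool is where the keying change is most visible: A_buffer and B_buffer map each normalized target module to a per-layer list of tensors, and init_buffer allocates one entry per name in target_modules. Below is a rough, self-contained sketch of that allocation pattern; the shapes, layer count, and module names are assumptions, not the LoRAMemoryPool logic.

# Rough sketch of the allocation pattern only; shapes, layer count, and module names are
# illustrative assumptions, not the LoRAMemoryPool implementation.
from typing import Callable, Dict, List, Set, Tuple

import torch

num_layer = 2
max_lora_rank = 8
hidden_size = 16
max_loras_per_batch = 4  # hypothetical pool capacity

A_buffer: Dict[str, List[torch.Tensor]] = {}


def get_lora_A_shape(module_name: str) -> Tuple[int, ...]:
    # Toy shape function standing in for LoRAMemoryPool.get_lora_A_shape.
    return (max_loras_per_batch, max_lora_rank, hidden_size)


def init_buffer(
    buffer: Dict[str, List[torch.Tensor]],
    target_modules: Set[str],
    get_lora_shape_fn: Callable[[str], Tuple[int, ...]],
) -> None:
    # One per-layer list of tensors for every normalized target module, keyed by that name.
    for module_name in target_modules:
        shape = get_lora_shape_fn(module_name)
        buffer[module_name] = [torch.zeros(shape) for _ in range(num_layer)]


init_buffer(A_buffer, {"qkv_proj", "down_proj"}, get_lora_A_shape)
print({name: (len(layers), tuple(layers[0].shape)) for name, layers in A_buffer.items()})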
python/sglang/srt/lora/utils.py

@@ -84,7 +84,7 @@ def get_hidden_dim(
     raise NotImplementedError()


-def get_normalized_lora_weight_names(
+def get_normalized_target_modules(
     target_modules: Iterable[str],
 ) -> set[str]:
     """

@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(
     result = set()
     for name in target_modules:
-        weight_name = params_mapping.get(name, name)
-        result.add(weight_name)
+        normalized_name = params_mapping.get(name, name)
+        result.add(normalized_name)
     return result

@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
     return stacked_rank[module_name] if module_name in stacked_rank else 1


-def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]]
-) -> Optional[str]:
+def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
     """
-    Get the weight name in lora_weight_names that can match target_name.
-    If there is a weight name in lora_weight_names that can match target_name, return this name
+    Get the target module name in target_modules that can match full_module_name.
+    If there is a target module name in target_modules that can match full_module_name, return this name
     Else raise ValueError.
     """
-    for weight_name in lora_weight_names:
-        if weight_name in target_name:
-            return weight_name
+    for target_module in target_modules:
+        if target_module in full_module_name:
+            return target_module
     raise ValueError(
-        f"Cannot find weight name for {target_name} in {lora_weight_names}"
+        f"Cannot find target module name for {full_module_name} in {target_modules}"
    )
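For a concrete feel of the renamed helpers, here is a self-contained sketch of the normalization pass. The mapping table is an assumption used purely for illustration of the q_proj/k_proj/v_proj -> qkv_proj style of stacking; the real table is params_mapping inside sglang.

# Illustration of the normalization step; PARAMS_MAPPING below is a hypothetical stand-in
# for sglang's params_mapping and only shows the kind of stacking it performs.
from typing import Iterable, Set

PARAMS_MAPPING = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}


def get_normalized_target_modules(target_modules: Iterable[str]) -> Set[str]:
    # Mirrors the loop in the hunk above: fall back to the original name when unmapped.
    result = set()
    for name in target_modules:
        normalized_name = PARAMS_MAPPING.get(name, name)
        result.add(normalized_name)
    return result


print(get_normalized_target_modules(["q_proj", "v_proj", "down_proj"]))
# Under the assumed mapping: {'qkv_proj', 'down_proj'}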
python/sglang/srt/utils.py

@@ -2874,6 +2874,8 @@ SUPPORTED_LORA_TARGET_MODULES = [
     "gate_proj",
     "up_proj",
     "down_proj",
+    "qkv_proj",
+    "gate_up_proj",
 ]

 LORA_TARGET_ALL_MODULES = "all"
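The two new entries close the loop with the memory pool: after normalization, an adapter's target modules are compared directly against the configured set (see _can_support in mem_pool.py above), so the stacked names "qkv_proj" and "gate_up_proj" must themselves be listed as supported. A short sketch of that check, using only names shown in this commit:

# Partial, illustrative reproduction: only the entries visible in this diff are listed.
SUPPORTED_LORA_TARGET_MODULES = [
    "gate_proj",
    "up_proj",
    "down_proj",
    "qkv_proj",
    "gate_up_proj",
]

# Hypothetical adapter whose target modules have already been normalized
# (e.g. q_proj/k_proj/v_proj collapsed into qkv_proj).
adapter_target_modules = {"qkv_proj", "gate_up_proj", "down_proj"}

# Mirrors the issubset() check performed in LoRAMemoryPool._can_support above.
print(adapter_target_modules.issubset(set(SUPPORTED_LORA_TARGET_MODULES)))  # True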