sglang / Commits / 8abd3e77

Commit 8abd3e77 (unverified), authored Jul 23, 2025 by Lifu Huang, committed by GitHub on Jul 23, 2025.

Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)

Parent: e885bfdc

Showing 11 changed files with 399 additions and 260 deletions (+399 / -260).
Files changed:

    python/sglang/srt/lora/lora_manager.py            +133  -169
    python/sglang/srt/lora/lora_registry.py           +124    -0
    python/sglang/srt/lora/mem_pool.py                  +2    -2
    python/sglang/srt/managers/io_struct.py            +19    -1
    python/sglang/srt/managers/scheduler.py             +3   -17
    python/sglang/srt/managers/tokenizer_manager.py    +28   -25
    python/sglang/srt/managers/tp_worker.py             +2    -4
    python/sglang/srt/model_executor/model_runner.py   +10   -15
    python/sglang/srt/server_args.py                   +19    -4
    test/srt/models/lora/test_lora_eviction.py         +58   -22
    test/srt/run_suite.py                               +1    -1
python/sglang/srt/lora/lora_manager.py — View file @ 8abd3e77

@@ -16,7 +16,7 @@
The typing import gains List:

    from typing import Dict, Iterable, List, Optional, Set, Tuple

@@ -26,6 +26,7 @@
A new import is added alongside the existing LoRA imports:

    from sglang.srt.lora.lora_registry import LoRARef

@@ -55,6 +56,7 @@ class LoRAManager:
The constructor accepts the initial adapters as LoRARef objects:

    tp_rank: int = 0,
    max_lora_rank: Optional[int] = None,
    target_modules: Optional[Iterable[str]] = None,
    lora_paths: Optional[Dict[str, LoRARef]] = None,
    ):

@@ -64,10 +66,6 @@
The constructor no longer stores `max_lora_rank` and `target_modules` as attributes directly; those lines are removed and the values are handed to `init_state` instead.

@@ -75,7 +73,11 @@
Initialization is delegated to `init_state`, which now receives the shape hints and the initial adapters:

    self.init_state(
        max_lora_rank=max_lora_rank,
        target_modules=target_modules,
        lora_paths=lora_paths,
    )

@@ -112,108 +114,87 @@
`create_lora_update_result` reports loaded adapters from the new `lora_refs` map:

    loaded_adapters={
        lora_ref.lora_name: lora_ref.lora_path for lora_ref in self.lora_refs.values()
    },

The batch-style `load_lora_adapters(self, lora_paths: Dict[str, str])` method is removed, along with the old `load_lora_adapter(self, lora_name, lora_path, update_state=True)` signature and the `update_state_from_configs` round-trip. `load_lora_adapter` now takes a `LoRARef` and keys all internal state by the stable LoRA ID:

    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
        """
        Load a single LoRA adapter from the specified path.

        Args:
            lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
        """
        assert (
            lora_ref.lora_name is not None and lora_ref.lora_path is not None
        ), "LoRARef must have both lora_name and lora_path set for loading."
        assert (
            lora_ref.lora_id not in self.loras
        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."

        try:
            # load configs
            new_adapter = LoRAConfig(lora_ref.lora_path)
            self.validate_new_adapter(new_adapter, lora_ref)
            self.configs[lora_ref.lora_id] = new_adapter

            # load weights
            self.load_lora_weights(lora_ref)

            # keep metadata for displayed messages
            self.lora_refs[lora_ref.lora_id] = lora_ref
        except Exception as e:
            return self.create_lora_update_result(
                success=False,
                error_message=str(e),
            )

        return self.create_lora_update_result(success=True)

    def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
        """
        Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
        """
        memory_pool = getattr(self, "memory_pool", None)
        incompatible = memory_pool and not memory_pool.can_support(lora_config)
        if incompatible:
            raise ValueError(
                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
                "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
                "included in `--enable_lora_modules`."
            )

    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
        """
        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
        delete the corresponding LoRA modules.
        """
        adapter = self.configs.get(lora_ref.lora_id, None)
        assert (
            adapter is not None
        ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."

        try:
            del self.configs[lora_ref.lora_id]
            del self.loras[lora_ref.lora_id]
            del self.lora_refs[lora_ref.lora_id]
        except Exception as e:
            return self.create_lora_update_result(
                success=False,
                error_message=str(e),
            )

        return self.create_lora_update_result(success=True)

    def prepare_lora_batch(self, forward_batch: ForwardBatch):
        # Load active loras into lora memory pool
        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
        cur_uids = set(forward_batch.lora_paths)
        assert len(cur_uids) <= self.max_loras_per_batch
        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)

@@ -233,10 +214,10 @@
The per-request loop now iterates over LoRA IDs rather than paths:

    weight_indices = [0] * len(forward_batch.lora_paths)
    lora_ranks = [0] * self.max_loras_per_batch
    scalings = [0] * self.max_loras_per_batch
    for i, uid in enumerate(forward_batch.lora_paths):
        weight_indices[i] = self.memory_pool.get_buffer_id(uid)
        if uid is not None:
            lora = self.loras[uid]
            lora_ranks[weight_indices[i]] = lora.config.r
            scalings[weight_indices[i]] = lora.scaling

@@ -326,7 +307,7 @@
Because `lora_modules` becomes a list indexed by layer, the update loop switches from `.items()` to `enumerate`:

    for layer_id, layer_modules in enumerate(self.lora_modules):
        for module_name, module in layer_modules.items():
            if "qkv_proj" in module_name:
                module.set_lora_info(

@@ -353,115 +334,94 @@
The mutable-state machinery is reorganized. The old `update_state_from_configs`, `apply_lora_configs`, `update_lora_adapters`, `update_lora_weight_names`, `update_memory_buffers`, and `update_lora_modules` methods (including the lazily created memory pool and the `memory_pool.can_support(self.configs.values())` re-validation path) are removed. `init_state` now takes the shape hints and initial adapters and performs one-time initialization through `init_*` counterparts:

    def init_state(
        self,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
        lora_paths: Optional[Dict[str, LoRARef]] = None,
    ):
        """
        Initialize the internal (mutable) state of the LoRAManager.

        When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
        the target modules and max_lora_rank.
        """
        assert lora_paths or (
            max_lora_rank is not None and target_modules is not None
        ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

        self.init_lora_adapters(lora_paths)
        self.init_lora_shapes(
            max_lora_rank=max_lora_rank,
            target_modules=target_modules,
        )
        self.init_lora_weight_names()
        self.init_lora_modules()
        self.init_memory_pool()

    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
        # Configs of all active LoRA adapters, indexed by LoRA ID.
        self.configs: Dict[str, LoRAConfig] = {}

        # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
        self.loras: Dict[str, LoRAAdapter] = {}

        # Mapping from LoRA ID to LoRARef object.
        self.lora_refs: Dict[str, LoRARef] = {}

        if lora_paths:
            for lora_ref in lora_paths.values():
                result = self.load_lora_adapter(lora_ref)
                if not result.success:
                    raise RuntimeError(
                        f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
                    )

    def init_lora_shapes(
        self,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
    ):
        """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""

        if target_modules is not None:
            self.target_modules = set(target_modules)
        else:
            self.target_modules = set()
            for config in self.configs.values():
                self.target_modules.update(config.target_modules)

        if max_lora_rank is not None:
            self.max_lora_rank = max_lora_rank
        else:
            self.max_lora_rank = max(
                [x.hf_config["r"] for x in self.configs.values()],
                default=0,
            )

    def init_lora_weight_names(self):
        """
        Add new LoRA weight names if needed based on the current `self.configs`.
        """

        # Target lora weight names for lora_a and lora_b modules respectively.
        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
        self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))

    def load_lora_weights(self, lora_ref: LoRARef):
        """
        Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
        """
        lora_adapter = LoRAAdapter(
            lora_ref.lora_id,
            self.configs[lora_ref.lora_id],
            self.base_hf_config,
            self.load_config,
            self.lora_backend,
        )
        lora_adapter.initialize_weights()
        self.loras[lora_ref.lora_id] = lora_adapter

        # Additional checks for flashinfer backend
        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend

@@ -472,7 +432,7 @@
`update_memory_buffers` is renamed to `init_memory_pool`; the FlashInfer single-rank/scaling restriction check above it is unchanged:

    assert (
        len(lora_dims) == 1 and len(scalings) == 1
    ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

    def init_memory_pool(self):
        """(Re)initialize the LoRA memory pool based on the current configurations."""
        self.memory_pool = LoRAMemoryPool(
            base_hf_config=self.base_hf_config,
            ...

@@ -490,7 +450,12 @@
`update_lora_modules` is renamed to `init_lora_modules`, and the lookup table becomes a list indexed by layer:

    def init_lora_modules(self):
        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
        self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
            {} for _ in range(self.base_hf_config.num_hidden_layers)
        ]

        # Target module names of customized layers defined in python/sglang/srt/layers
        # e.g., {"qkv_proj", "o_proj"}
        customized_target_names = get_customized_names_from_hf_names(
        ...

@@ -511,7 +476,6 @@
The `if module_name not in self.lora_modules[layer_id]:` guard is dropped; matched modules are always (re)assigned:

    # The module should be converted if it is included in target_names
    if module_name.split(".")[-1] in customized_target_names:
        layer_id = get_layer_id(module_name)
        self.lora_modules[layer_id][module_name] = self.set_lora_module(
            module_name, module
        )
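For context on why this commit keys internal state by `lora_id` rather than by adapter name, here is a minimal standalone sketch (not part of the commit; the `make_cache_key` helper is hypothetical) showing how a stable ID keeps prefix-cache keys distinct even when a user-facing name is reused for a different adapter:

    from sglang.srt.lora.lora_registry import LoRARef

    # Two different adapters registered under the same user-facing name at different times.
    ref_v1 = LoRARef(lora_name="my-adapter", lora_path="/models/adapter-v1")
    ref_v2 = LoRARef(lora_name="my-adapter", lora_path="/models/adapter-v2")

    # Hypothetical cache-key helper: keying by name would alias the two adapters,
    # keying by the auto-generated ID does not.
    def make_cache_key(token_prefix: tuple, lora_id: str) -> tuple:
        return (lora_id,) + token_prefix

    key_v1 = make_cache_key((1, 2, 3), ref_v1.lora_id)
    key_v2 = make_cache_key((1, 2, 3), ref_v2.lora_id)
    assert key_v1 != key_v2  # a reused name no longer collides in the cache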
python/sglang/srt/lora/lora_registry.py — new file (0 → 100644) — View file @ 8abd3e77

    # Copyright 2023-2024 SGLang Team
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================

    import asyncio
    from dataclasses import dataclass, field, fields
    from typing import Dict, List, Optional, Union
    from uuid import uuid4


    @dataclass(frozen=True, slots=True)
    class LoRARef:
        """
        Reference record for a LoRA model.

        This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
        eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
        keys (e.g., radix cache).
        """

        lora_id: str = field(default_factory=lambda: uuid4().hex)
        lora_name: Optional[str] = None
        lora_path: Optional[str] = None

        def __post_init__(self):
            if self.lora_id is None:
                raise ValueError("lora_id cannot be None")

        def __str__(self) -> str:
            parts = [
                f"{f.name}={value}"
                for f in fields(self)
                if (value := getattr(self, f.name)) is not None
            ]
            return f"{self.__class__.__name__}({', '.join(parts)})"


    class LoRARegistry:
        """
        The central registry to keep track of available LoRA adapters.

        TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
        to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
        """

        def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
            assert lora_paths is None or all(
                isinstance(lora, LoRARef) for lora in lora_paths.values()
            ), (
                "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
                "Please file an issue if you see this error."
            )

            # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
            self._registry: Dict[str, LoRARef] = dict(lora_paths or {})

        async def register(self, lora_ref: LoRARef):
            """
            Register a new LoRARef object in the registry.

            Args:
                lora_ref (LoRARef): The LoRARef object to register.
            """
            if lora_ref.lora_name in self._registry:
                raise ValueError(
                    f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
                )
            self._registry[lora_ref.lora_name] = lora_ref

        async def unregister(self, lora_name: str) -> str:
            """
            Unregister a LoRARef object from the registry and returns the removed LoRA ID.

            Args:
                lora_name (str): The name of the LoRA model to unregister.
            """
            lora_ref = self._registry.get(lora_name, None)
            if lora_ref is None:
                raise ValueError(
                    f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
                )
            del self._registry[lora_name]
            return lora_ref.lora_id

        async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
            """
            Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
            by incrementing its counter.

            TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
            """

            async def _acquire_single(name: str) -> str:
                lora_ref = self._registry.get(name, None)
                if lora_ref is None:
                    raise ValueError(
                        f"The following requested LoRA adapters are not loaded: {name}\n"
                        f"Loaded adapters: {self._registry.keys()}."
                    )
                # await self._counters[lora_ref.lora_id].increment()
                return lora_ref.lora_id

            if isinstance(lora_name, str):
                lora_id = await _acquire_single(lora_name)
                return lora_id
            elif isinstance(lora_name, list):
                lora_ids = await asyncio.gather(
                    *[_acquire_single(name) for name in lora_name]
                )
                return lora_ids
            else:
                raise TypeError("lora_name must be either a string or a list of strings.")
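A minimal usage sketch of the registry API above (adapter name and path are placeholders; assumes an async context):

    import asyncio
    from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry

    async def main():
        registry = LoRARegistry()

        # Register an adapter under a user-facing name; the stable ID is generated automatically.
        ref = LoRARef(lora_name="sql-adapter", lora_path="/path/to/sql-adapter")
        await registry.register(ref)

        # Requests refer to adapters by name; acquire() translates names to stable IDs.
        lora_id = await registry.acquire("sql-adapter")
        assert lora_id == ref.lora_id

        # Unregister returns the ID so callers can tear down backend state keyed by it.
        removed_id = await registry.unregister("sql-adapter")
        assert removed_id == lora_id

    asyncio.run(main())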
python/sglang/srt/lora/mem_pool.py — View file @ 8abd3e77

@@ -153,7 +153,7 @@ and @@ -186,7 +186,7 @@
Two method signatures are updated to match the LoRAManager change: the `lora_modules` parameter in `prepare_lora_batch(self, cur_uids, lora_adapters, lora_modules)` and in the per-adapter buffer-loading method (taking `uid`, `buffer_id`, `lora_adapter`, `lora_modules`) changes type:

    lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],   # before
    lora_modules: List[Dict[str, BaseLayerWithLoRA]],        # after
python/sglang/srt/managers/io_struct.py — View file @ 8abd3e77

@@ -22,6 +22,7 @@
New import:

    from sglang.srt.lora.lora_registry import LoRARef

@@ -1067,19 +1068,36 @@
`LoadLoRAAdapterReqInput` and `UnloadLoRAAdapterReqInput` gain a `lora_id` field plus a `to_ref()` helper, and `LoRAUpdateResult.loaded_adapters` now maps names to LoRARef objects instead of plain path strings:

    class LoadLoRAAdapterReqInput:
        lora_name: str
        # The path of loading.
        lora_path: str
        # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
        lora_id: Optional[str] = None

        def to_ref(self) -> LoRARef:
            return LoRARef(
                lora_id=self.lora_id,
                lora_name=self.lora_name,
                lora_path=self.lora_path,
            )


    @dataclass
    class UnloadLoRAAdapterReqInput:
        # The name of lora module to unload.
        lora_name: str
        # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
        lora_id: Optional[str] = None

        def to_ref(self) -> LoRARef:
            return LoRARef(
                lora_id=self.lora_id,
                lora_name=self.lora_name,
            )


    @dataclass
    class LoRAUpdateResult:
        success: bool
        error_message: Optional[str] = None
        loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)


    LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
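As a quick illustration of the new request-to-ref conversion (values are placeholders), the TokenizerManager can stamp a freshly generated ID onto the request and then derive a LoRARef from it:

    from sglang.srt.lora.lora_registry import LoRARef
    from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput

    req = LoadLoRAAdapterReqInput(lora_name="sql-adapter", lora_path="/path/to/sql-adapter")

    # The TokenizerManager generates the unique ID before forwarding the request.
    new_adapter = LoRARef(lora_name=req.lora_name, lora_path=req.lora_path)
    req.lora_id = new_adapter.lora_id

    ref = req.to_ref()
    assert ref.lora_id == new_adapter.lora_id and ref.lora_name == "sql-adapter"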
python/sglang/srt/managers/scheduler.py — View file @ 8abd3e77

@@ -247,7 +247,7 @@
The scheduler no longer keeps the raw path map; it only needs the feature flag:

    self.lora_paths = server_args.lora_paths     # before
    self.enable_lora = server_args.enable_lora   # after

@@ -1706,13 +1706,13 @@
The prefill-batching checks are gated on `self.enable_lora` instead of `self.lora_paths`:

    if self.enable_lora:
        lora_set = set([req.lora_path for req in self.running_batch.reqs])

    # Get requests from the waiting queue to a new prefill batch
    for req in self.waiting_queue:
        if (
            self.enable_lora
            and len(
                lora_set
                | set([req.lora_path for req in adder.can_run_list])
                ...

@@ -2466,12 +2466,6 @@ and @@ -2480,14 +2474,6 @@
The cache-flush-on-success blocks are removed from both adapter handlers; they now simply forward to the TP worker and return its result:

    def load_lora_adapter(self, ...):
        """In-place loading a new lora adapter from disk or huggingface."""
        result = self.tp_worker.load_lora_adapter(recv_req)
        return result

Previously each handler called `flush_cache_success = self.flush_cache()` with an assert ("Cache flush failed after loading lora adapter." / "Cache flush failed after unloading LoRA weights") on success and logged `result.error_message` on failure; the same simplification is applied to `unload_lora_adapter`.
python/sglang/srt/managers/tokenizer_manager.py — View file @ 8abd3e77

@@ -62,6 +62,7 @@
New import:

    from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry

@@ -242,11 +243,11 @@
The ad-hoc `loaded_lora_adapters` dict (`Dict[str, str]` built from `server_args.lora_paths`) is replaced by a registry:

    # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
    # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
    # serves as the source of truth for available adapters and maps user-friendly LoRA names
    # to internally used unique LoRA IDs.
    self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})

@@ -523,6 +524,10 @@
Requests have their user-facing LoRA names resolved to stable IDs before validation and tokenization:

    if self.server_args.enable_lora and obj.lora_path:
        # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
        obj.lora_path = await self.lora_registry.acquire(obj.lora_path)

    self._validate_one_request(obj, input_ids)
    return self._create_tokenized_object(
        obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids

@@ -574,8 +579,6 @@ and @@ -689,21 +692,6 @@
The `_validate_lora_adapters` helper and its call site are removed; unknown adapter names are now rejected by `LoRARegistry.acquire`, which raises with the list of loaded adapters.

@@ -1054,8 +1042,18 @@
`load_lora_adapter` generates a fresh LoRARef, stamps its ID onto the request, and registers the adapter only after the backend reports success (the old `self.loaded_lora_adapters = result.loaded_adapters` bookkeeping is removed):

    async with self.model_update_lock.writer_lock:
        # Generate new uniquely identifiable LoRARef object.
        new_adapter = LoRARef(
            lora_name=obj.lora_name,
            lora_path=obj.lora_path,
        )

        # Register the new adapter in the registry.
        obj.lora_id = new_adapter.lora_id
        result = (await self.update_lora_adapter_communicator(obj))[0]
        if result.success:
            await self.lora_registry.register(new_adapter)

        return result

@@ -1069,6 +1067,10 @@ and @@ -1080,8 +1082,9 @@
`unload_lora_adapter` asserts that a name was provided and unregisters first, translating the name to the ID the backend needs:

    assert (
        obj.lora_name is not None
    ), "lora_name must be provided to unload LoRA adapter"
    ...
    async with self.model_update_lock.writer_lock:
        obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
        result = (await self.update_lora_adapter_communicator(obj))[0]

        return result

@@ -1309,7 +1312,7 @@
An unrelated cleanup switches the crash-dump filename to a double-quoted f-string so the strftime format can use single quotes:

    f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
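The name-to-ID substitution performed in the tokenizer path can be sketched in isolation, using the registry directly rather than the real GenerateReqInput flow (names and paths are placeholders):

    import asyncio
    from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry

    async def resolve_lora_path(registry: LoRARegistry, lora_path):
        # Mirrors the tokenizer-side rewrite: user-facing names in, stable IDs out.
        return await registry.acquire(lora_path)

    async def main():
        registry = LoRARegistry()
        await registry.register(LoRARef(lora_name="adapter-a", lora_path="/adapters/a"))
        await registry.register(LoRARef(lora_name="adapter-b", lora_path="/adapters/b"))

        # A single name or a per-request list of names are both supported.
        single_id = await resolve_lora_path(registry, "adapter-a")
        batch_ids = await resolve_lora_path(registry, ["adapter-a", "adapter-b"])
        print(single_id, batch_ids)

    asyncio.run(main())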
python/sglang/srt/managers/tp_worker.py — View file @ 8abd3e77

@@ -293,11 +293,9 @@
Both handlers now forward a LoRARef built from the request instead of raw name/path strings:

    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
        return result

    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
        return result
python/sglang/srt/model_executor/model_runner.py — View file @ 8abd3e77

@@ -68,6 +68,7 @@
New import:

    from sglang.srt.lora.lora_registry import LoRARef

@@ -890,44 +891,38 @@
The LoRAManager construction now passes the initial adapters directly, and the follow-up `load_lora_adapters(self.server_args.lora_paths or {})` call, its "LoRA manager ready" logging, and the RuntimeError-on-failure branch are removed:

    tp_rank=self.tp_rank,
    max_lora_rank=self.server_args.max_lora_rank,
    target_modules=self.server_args.lora_target_modules,
    lora_paths=self.server_args.lora_paths,
    )

`load_lora_adapter` and `unload_lora_adapter` take a LoRARef and log it directly instead of separate name/path fields:

    def load_lora_adapter(self, lora_ref: LoRARef):
        """Load a new lora adapter from disk or huggingface."""

        logger.info(
            f"LoRA adapter loading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.load_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter loading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def unload_lora_adapter(self, lora_ref: LoRARef):
        """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""

        logger.info(
            f"LoRA adapter unloading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.unload_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter unloading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
...
python/sglang/srt/server_args.py
View file @
8abd3e77
...
@@ -20,10 +20,10 @@ import logging
...
@@ -20,10 +20,10 @@ import logging
import
os
import
os
import
random
import
random
import
tempfile
import
tempfile
from
token
import
OP
from
typing
import
List
,
Literal
,
Optional
,
Union
from
typing
import
List
,
Literal
,
Optional
,
Union
from
sglang.srt.hf_transformers_utils
import
check_gguf_file
,
get_config
from
sglang.srt.hf_transformers_utils
import
check_gguf_file
,
get_config
from
sglang.srt.lora.lora_registry
import
LoRARef
from
sglang.srt.reasoning_parser
import
ReasoningParser
from
sglang.srt.reasoning_parser
import
ReasoningParser
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
LORA_TARGET_ALL_MODULES
,
LORA_TARGET_ALL_MODULES
,
...
@@ -145,7 +145,7 @@ class ServerArgs:
...
@@ -145,7 +145,7 @@ class ServerArgs:
enable_lora
:
Optional
[
bool
]
=
None
enable_lora
:
Optional
[
bool
]
=
None
max_lora_rank
:
Optional
[
int
]
=
None
max_lora_rank
:
Optional
[
int
]
=
None
lora_target_modules
:
Optional
[
Union
[
set
[
str
],
List
[
str
]]]
=
None
lora_target_modules
:
Optional
[
Union
[
set
[
str
],
List
[
str
]]]
=
None
lora_paths
:
Optional
[
Union
[
dict
[
str
,
str
],
List
[
str
]]]
=
None
lora_paths
:
Optional
[
Union
[
dict
[
str
,
str
],
dict
[
str
,
LoRARef
],
List
[
str
]]]
=
None
max_loras_per_batch
:
int
=
8
max_loras_per_batch
:
int
=
8
lora_backend
:
str
=
"triton"
lora_backend
:
str
=
"triton"
...
@@ -1843,9 +1843,24 @@ class ServerArgs:
...
@@ -1843,9 +1843,24 @@ class ServerArgs:
for
lora_path
in
lora_paths
:
for
lora_path
in
lora_paths
:
if
"="
in
lora_path
:
if
"="
in
lora_path
:
name
,
path
=
lora_path
.
split
(
"="
,
1
)
name
,
path
=
lora_path
.
split
(
"="
,
1
)
self
.
lora_paths
[
name
]
=
path
self
.
lora_paths
[
name
]
=
LoRARef
(
lora_name
=
name
,
lora_path
=
path
)
else
:
else
:
self
.
lora_paths
[
lora_path
]
=
lora_path
self
.
lora_paths
[
lora_path
]
=
LoRARef
(
lora_name
=
lora_path
,
lora_path
=
lora_path
,
)
elif
isinstance
(
self
.
lora_paths
,
dict
):
self
.
lora_paths
=
{
k
:
LoRARef
(
lora_name
=
k
,
lora_path
=
v
)
for
k
,
v
in
self
.
lora_paths
.
items
()
}
elif
self
.
lora_paths
is
None
:
self
.
lora_paths
=
{}
else
:
raise
ValueError
(
f
"Invalid type for --lora-paths:
{
type
(
self
.
lora_paths
)
}
. "
"Expected a list or a dictionary."
)
# Expand target modules
# Expand target modules
if
self
.
lora_target_modules
:
if
self
.
lora_target_modules
:
...
...
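The normalization above accepts both the `name=path` list form and the bare-path form; a small sketch of the resulting structure (paths are placeholders, outside the ServerArgs class for brevity):

    from sglang.srt.lora.lora_registry import LoRARef

    # "--lora-paths sql=/adapters/sql /adapters/chat" style input, after argparse splitting.
    raw = ["sql=/adapters/sql", "/adapters/chat"]

    lora_paths = {}
    for entry in raw:
        if "=" in entry:
            name, path = entry.split("=", 1)
            lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
        else:
            # Bare paths double as their own name.
            lora_paths[entry] = LoRARef(lora_name=entry, lora_path=entry)

    for name, ref in lora_paths.items():
        print(name, "->", ref)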
test/srt/models/lora/test_lora_eviction.py — View file @ 8abd3e77

@@ -12,6 +12,7 @@
    import contextlib
    import multiprocessing as mp
    import unittest
    from typing import Dict, List, Tuple

@@ -39,6 +40,16 @@
A context manager is added for adapters loaded at runtime:

    BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"


    @contextlib.contextmanager
    def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str):
        """A context manager to load and automatically unload a LoRA adapter."""
        try:
            runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path)
            yield
        finally:
            runner.unload_lora_adapter(lora_name=lora_name)

@@ -51,55 +62,80 @@
A new test exercises eviction when the same LoRA name is reused for different adapters, and `_run_test` gains `reuse_lora_name` and default-valued `reverse` parameters:

    def test_lora_eviction_with_reused_lora_name(self):
        """
        Test LoRA eviction with reused LoRA names.

        This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior
        works correctly when reusing LoRA names.
        """
        output_history = {}
        self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1)
        self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1)

    def _run_test(
        self,
        lora_paths: List[str],
        output_history: Dict[Tuple[str, str], str],
        reverse: bool = False,
        repeat: int = 2,
        reuse_lora_name: bool = False,
    ):
        REUSED_LORA_NAME = "lora"
        max_new_tokens = 256
        backend = "triton"
        torch_dtype = torch.float16
        base_path = BASE_MODEL
        assert len(lora_paths) >= 2

        initial_lora_paths = lora_paths if not reuse_lora_name else None

        # Initialize runners
        with SRTRunner(
            base_path,
            torch_dtype=torch_dtype,
            model_type="generation",
            lora_paths=initial_lora_paths,
            max_loras_per_batch=1,
            lora_backend=backend,
            disable_radix_cache=True,
            enable_lora=True,
            max_lora_rank=256,
            lora_target_modules=["all"],
        ) as srt_runner:
            adapter_sequence = lora_paths if not reverse else lora_paths[::-1]
            for i in range(repeat):
                for j, lora_path in enumerate(adapter_sequence):
                    print(
                        f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1} / {len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1} / {repeat} ---"
                    )
                    lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path
                    context = (
                        dynamically_loaded_adapter(srt_runner, lora_path, lora_name)
                        if reuse_lora_name
                        else contextlib.nullcontext()
                    )
                    with context:
                        for prompt in PROMPTS:
                            print("\nprompt:\n", prompt)
                            srt_outputs = srt_runner.forward(
                                [prompt],
                                max_new_tokens=max_new_tokens,
                                lora_paths=[lora_name],
                            )
                            output = srt_outputs.output_strs[0].strip()
                            print("\noutput:\n", output)

                            prev_output = output_history.get((lora_path, prompt))
                            if prev_output is not None:
                                self.assertEqual(
                                    prev_output,
                                    output,
                                    f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
                                )
                            else:
                                output_history[(lora_path, prompt)] = output

    if __name__ == "__main__":
test/srt/run_suite.py — View file @ 8abd3e77

@@ -14,7 +14,7 @@
The estimated time for the eviction test entry is bumped from 120 to 200:

    suites = {
        "per-commit": [
            TestFile("models/lora/test_lora.py", 200),
            TestFile("models/lora/test_lora_eviction.py", 200),
            TestFile("models/lora/test_lora_backend.py", 99),
            TestFile("models/lora/test_multi_lora_backend.py", 60),
            TestFile("models/lora/test_lora_cuda_graph.py", 250),