Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
df906455
Unverified
Commit
df906455
authored
Jul 27, 2025
by
Lifu Huang
Committed by
GitHub
Jul 27, 2025
Browse files
Support overlapped lora updates (#8213)
parent
95217a9b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
204 additions
and
35 deletions
+204
-35
python/sglang/srt/lora/lora_registry.py
python/sglang/srt/lora/lora_registry.py
+92
-28
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+24
-5
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+87
-0
test/srt/test_bench_serving.py
test/srt/test_bench_serving.py
+1
-2
No files found.
python/sglang/srt/lora/lora_registry.py
View file @
df906455
...
...
@@ -14,10 +14,14 @@
import
asyncio
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
Dict
,
List
,
Optional
,
Union
from
uuid
import
uuid4
from
sglang.srt.aio_rwlock
import
RWLock
from
sglang.srt.utils
import
ConcurrentCounter
@
dataclass
(
frozen
=
True
)
class
LoRARef
:
...
...
@@ -48,10 +52,11 @@ class LoRARef:
class
LoRARegistry
:
"""
The central registry to keep track of available LoRA adapters.
The central registry to keep track of available LoRA adapters
and ongoing LoRA requests
.
TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
"""
def
__init__
(
self
,
lora_paths
:
Optional
[
Dict
[
str
,
LoRARef
]]
=
None
):
...
...
@@ -62,8 +67,19 @@ class LoRARegistry:
"Please file an issue if you see this error."
)
# A read-write lock to ensure adapters loading / unloading operations are exclusive.
# Please note that the counter increment/decrement operations are not synchronized through this
# lock, as they are designed to be non-blocking and can be performed concurrently.
self
.
_registry_lock
=
RWLock
()
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
self
.
_registry
:
Dict
[
str
,
LoRARef
]
=
dict
(
lora_paths
or
{})
self
.
_registry
:
Dict
[
str
,
LoRARef
]
=
{}
# Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
self
.
_counters
:
Dict
[
str
,
ConcurrentCounter
]
=
{}
# Initialize the registry with provided LoRA paths, if present.
if
lora_paths
:
for
lora_ref
in
lora_paths
.
values
():
self
.
_register_adapter
(
lora_ref
)
async
def
register
(
self
,
lora_ref
:
LoRARef
):
"""
...
...
@@ -72,11 +88,8 @@ class LoRARegistry:
Args:
lora_ref (LoRARef): The LoRARef object to register.
"""
if
lora_ref
.
lora_name
in
self
.
_registry
:
raise
ValueError
(
f
"LoRA with name
{
lora_ref
.
lora_name
}
already exists. Loaded LoRAs:
{
self
.
_registry
.
keys
()
}
"
)
self
.
_registry
[
lora_ref
.
lora_name
]
=
lora_ref
async
with
self
.
_registry_lock
.
writer_lock
:
self
.
_register_adapter
(
lora_ref
)
async
def
unregister
(
self
,
lora_name
:
str
)
->
str
:
"""
...
...
@@ -85,12 +98,14 @@ class LoRARegistry:
Args:
lora_name (str): The name of the LoRA model to unregister.
"""
lora_ref
=
self
.
_registry
.
get
(
lora_name
,
None
)
if
lora_ref
is
None
:
raise
ValueError
(
f
"LoRA with name
{
lora_name
}
does not exist. Loaded LoRAs:
{
self
.
_registry
.
keys
()
}
"
)
del
self
.
_registry
[
lora_name
]
async
with
self
.
_registry_lock
.
writer_lock
:
lora_ref
=
self
.
_registry
.
get
(
lora_name
,
None
)
if
lora_ref
is
None
:
raise
ValueError
(
f
"LoRA with name
{
lora_name
}
does not exist. Loaded LoRAs:
{
self
.
_registry
.
keys
()
}
"
)
del
self
.
_registry
[
lora_name
]
del
self
.
_counters
[
lora_ref
.
lora_id
]
return
lora_ref
.
lora_id
...
...
@@ -98,27 +113,76 @@ class LoRARegistry:
"""
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
by incrementing its counter.
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
"""
async
def
_acquire_single
(
name
:
str
)
->
str
:
def
_lookup
(
name
:
str
)
->
str
:
lora_ref
=
self
.
_registry
.
get
(
name
,
None
)
if
lora_ref
is
None
:
raise
ValueError
(
f
"The following requested LoRA adapters are not loaded:
{
name
}
\n
"
f
"Loaded adapters:
{
self
.
_registry
.
keys
()
}
."
)
# await self._counters[lora_ref.lora_id].increment()
return
lora_ref
.
lora_id
if
isinstance
(
lora_name
,
str
):
lora_id
=
await
_acquire_single
(
lora_name
)
return
lora_id
elif
isinstance
(
lora_name
,
list
):
lora_ids
=
await
asyncio
.
gather
(
*
[
_acquire_single
(
name
)
for
name
in
lora_name
]
async
with
self
.
_registry_lock
.
reader_lock
:
if
isinstance
(
lora_name
,
str
):
lora_id
=
_lookup
(
lora_name
)
await
self
.
_counters
[
lora_id
].
increment
(
notify_all
=
False
)
return
lora_id
elif
isinstance
(
lora_name
,
list
):
lora_ids
=
[
_lookup
(
name
)
for
name
in
lora_name
]
# Increment the counters only after all IDs are looked up.
await
asyncio
.
gather
(
*
[
self
.
_counters
[
id
].
increment
(
notify_all
=
False
)
for
id
in
lora_ids
]
)
return
lora_ids
else
:
raise
TypeError
(
"lora_name must be either a string or a list of strings."
)
async
def
release
(
self
,
lora_id
:
Union
[
str
,
List
[
str
]]):
"""
Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.
"""
async
with
self
.
_registry_lock
.
reader_lock
:
if
isinstance
(
lora_id
,
str
):
await
self
.
_counters
[
lora_id
].
decrement
()
elif
isinstance
(
lora_id
,
list
):
await
asyncio
.
gather
(
*
[
self
.
_counters
[
id
].
decrement
()
for
id
in
lora_id
]
)
else
:
raise
TypeError
(
"lora_id must be either a string or a list of strings."
)
async
def
wait_for_unload
(
self
,
lora_id
:
str
):
"""
Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use.
This is useful for ensuring that a LoRA adapter can be safely unloaded.
This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
which itself is guaranteed to be sequential.
"""
assert
(
lora_id
not
in
self
.
_registry
),
"wait_for_unload should only be called after the LoRA adapter has been unregistered. "
counter
=
self
.
_counters
.
get
(
lora_id
)
if
counter
:
# Wait until no requests are using this LoRA adapter.
await
counter
.
wait_for_zero
()
del
self
.
_counters
[
lora_id
]
def
_register_adapter
(
self
,
lora_ref
:
LoRARef
):
"""
Internal helper method to register a LoRA adapter.
"""
if
lora_ref
.
lora_name
in
self
.
_registry
:
raise
ValueError
(
f
"LoRA with name
{
lora_ref
.
lora_name
}
already exists. Loaded LoRAs:
{
self
.
_registry
.
keys
()
}
"
)
return
lora_
ids
el
se
:
raise
TypeError
(
"lora_name must be either a string or a list of strings."
)
self
.
_registry
[
lora_ref
.
lora_name
]
=
lora_
ref
s
el
f
.
_counters
[
lora_ref
.
lora_id
]
=
ConcurrentCounter
()
return
lora_ref
python/sglang/srt/managers/tokenizer_manager.py
View file @
df906455
...
...
@@ -282,6 +282,11 @@ class TokenizerManager:
None
)
# Lock to serialize LoRA update operations.
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
# LoRA updates and inference to overlap.
self
.
lora_update_lock
=
asyncio
.
Lock
()
# For pd disaggregtion
self
.
disaggregation_mode
=
DisaggregationMode
(
self
.
server_args
.
disaggregation_mode
...
...
@@ -537,7 +542,8 @@ class TokenizerManager:
mm_inputs
=
None
if
self
.
server_args
.
enable_lora
and
obj
.
lora_path
:
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
# Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
# `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
obj
.
lora_path
=
await
self
.
lora_registry
.
acquire
(
obj
.
lora_path
)
self
.
_validate_one_request
(
obj
,
input_ids
)
...
...
@@ -747,6 +753,10 @@ class TokenizerManager:
msg
=
f
"Finish: obj=
{
dataclass_to_string_truncated
(
obj
,
max_length
,
skip_names
=
skip_names
)
}
, out=
{
dataclass_to_string_truncated
(
out
,
max_length
,
skip_names
=
out_skip_names
)
}
"
logger
.
info
(
msg
)
# Mark ongoing LoRA request as finished.
if
self
.
server_args
.
enable_lora
and
obj
.
lora_path
:
await
self
.
lora_registry
.
release
(
obj
.
lora_path
)
# Check if this was an abort/error created by scheduler
if
isinstance
(
out
[
"meta_info"
].
get
(
"finish_reason"
),
dict
):
finish_reason
=
out
[
"meta_info"
][
"finish_reason"
]
...
...
@@ -1053,16 +1063,18 @@ class TokenizerManager:
obj
.
lora_path
,
)
async
with
self
.
model
_update_lock
.
writer_lock
:
async
with
self
.
lora
_update_lock
:
# Generate new uniquely identifiable LoRARef object.
new_adapter
=
LoRARef
(
lora_name
=
obj
.
lora_name
,
lora_path
=
obj
.
lora_path
,
)
#
Regist
er the
new adapter in the registry
.
#
Trigg
er the
actual loading operation at the backend processes
.
obj
.
lora_id
=
new_adapter
.
lora_id
result
=
(
await
self
.
update_lora_adapter_communicator
(
obj
))[
0
]
# Register the LoRA adapter only after loading is successful.
if
result
.
success
:
await
self
.
lora_registry
.
register
(
new_adapter
)
...
...
@@ -1093,8 +1105,15 @@ class TokenizerManager:
obj
.
lora_name
,
)
async
with
self
.
model_update_lock
.
writer_lock
:
obj
.
lora_id
=
await
self
.
lora_registry
.
unregister
(
obj
.
lora_name
)
async
with
self
.
lora_update_lock
:
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
# from being started.
lora_id
=
await
self
.
lora_registry
.
unregister
(
obj
.
lora_name
)
obj
.
lora_id
=
lora_id
# Initiate the actual unloading operation at the backend processes only after all
# ongoing requests using this LoRA adapter are finished.
await
self
.
lora_registry
.
wait_for_unload
(
lora_id
)
result
=
(
await
self
.
update_lora_adapter_communicator
(
obj
))[
0
]
return
result
...
...
python/sglang/srt/utils.py
View file @
df906455
...
...
@@ -15,6 +15,7 @@
from
__future__
import
annotations
import
asyncio
import
builtins
import
ctypes
import
dataclasses
...
...
@@ -2862,3 +2863,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
]
LORA_TARGET_ALL_MODULES
=
"all"
class
ConcurrentCounter
:
"""
An asynchronous counter for managing concurrent tasks that need
coordinated increments, decrements, and waiting until the count reaches zero.
This class is useful for scenarios like tracking the number of in-flight tasks
and waiting for them to complete.
"""
def
__init__
(
self
,
initial
:
int
=
0
):
"""
Initialize the counter with an optional initial value.
Args:
initial (int): The initial value of the counter. Default is 0.
"""
self
.
_count
=
initial
self
.
_condition
=
asyncio
.
Condition
()
def
value
(
self
)
->
int
:
"""
Return the current value of the counter.
Note:
This method is not synchronized. It may return a stale value
if other coroutines are concurrently modifying the counter.
Returns:
int: The current counter value.
"""
return
self
.
_count
def
__repr__
(
self
)
->
str
:
"""Return an informative string representation of the counter."""
return
f
"<ConcurrentCounter value=
{
self
.
value
()
}
>"
async
def
increment
(
self
,
n
:
int
=
1
,
notify_all
:
bool
=
True
):
"""
Atomically increment the counter by a given amount and notify all waiters.
Args:
n (int): The amount to increment the counter by. Default is 1.
notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
"""
async
with
self
.
_condition
:
self
.
_count
+=
n
if
notify_all
:
self
.
_condition
.
notify_all
()
async
def
decrement
(
self
,
n
:
int
=
1
,
notify_all
:
bool
=
True
):
"""
Atomically decrement the counter by a given amount and notify all waiters.
Args:
n (int): The amount to decrement the counter by. Default is 1.
notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
"""
async
with
self
.
_condition
:
self
.
_count
-=
n
if
notify_all
:
self
.
_condition
.
notify_all
()
async
def
wait_for
(
self
,
condition
:
Callable
[[
int
],
bool
]):
"""
Asynchronously wait until the counter satisfies a given condition.
This suspends the calling coroutine without blocking the thread, allowing
other tasks to run while waiting. When the condition is met, the coroutine resumes.
Args:
condition (Callable[[int], bool]): A function that takes the current counter value
and returns True when the condition is satisfied.
"""
async
with
self
.
_condition
:
await
self
.
_condition
.
wait_for
(
lambda
:
condition
(
self
.
_count
))
async
def
wait_for_zero
(
self
):
"""
Asynchronously wait until the counter reaches zero.
This suspends the calling coroutine without blocking the thread, allowing
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
"""
self
.
wait_for
(
lambda
count
:
count
==
0
)
test/srt/test_bench_serving.py
View file @
df906455
...
...
@@ -231,8 +231,7 @@ class TestBenchServing(CustomTestCase):
f
"median_ttft_ms:
{
res
[
'median_ttft_ms'
]:.
2
f
}
ms
\n
"
)
self
.
assertLess
(
res
[
"median_e2e_latency_ms"
],
4000
)
# TODO (lifuhuang): This will be fixed by the overlapped LoRA update in a separate PR.
self
.
assertLess
(
res
[
"median_ttft_ms"
],
1600
)
self
.
assertLess
(
res
[
"median_ttft_ms"
],
80
)
def
_run_lora_latency_test
(
self
,
enable_background_task
:
bool
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment