sglang / Commits / 49538d11 (unverified)

Support dynamic LoRA loading / unloading in engine/server API (#7446)

Authored by Lifu Huang on Jun 27, 2025; committed by GitHub on Jun 27, 2025.
Parent commit: cfe2edac
Showing 14 changed files with 948 additions and 30 deletions.
python/sglang/srt/entrypoints/EngineBase.py              +8    -0
python/sglang/srt/entrypoints/engine.py                  +25   -0
python/sglang/srt/entrypoints/http_server.py             +36   -0
python/sglang/srt/lora/lora.py                           +4    -5
python/sglang/srt/lora/lora_manager.py                   +73   -20
python/sglang/srt/managers/io_struct.py                  +25   -1
python/sglang/srt/managers/scheduler.py                  +36   -0
python/sglang/srt/managers/tokenizer_manager.py          +55   -0
python/sglang/srt/managers/tp_worker.py                  +12   -0
python/sglang/srt/managers/tp_worker_overlap_thread.py   +8    -0
python/sglang/srt/model_executor/model_runner.py         +41   -3
python/sglang/test/runners.py                            +8    -1
test/srt/models/lora/test_lora_update.py                 +616  -0
test/srt/run_suite.py                                    +1    -0
python/sglang/srt/entrypoints/EngineBase.py

@@ -48,6 +48,14 @@ class EngineBase(ABC):
         """Update model weights with in-memory tensor data."""
         pass

+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+        pass
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+        pass
+
     @abstractmethod
     def release_memory_occupation(self):
         """Release GPU memory occupation temporarily."""
python/sglang/srt/entrypoints/engine.py

@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     ImageDataItem,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -478,6 +480,29 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )

+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+        obj = LoadLoRAAdapterReqInput(
+            lora_name=lora_name,
+            lora_path=lora_path,
+        )
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.load_lora_adapter(obj, None)
+        )
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+        obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.unload_lora_adapter(obj, None)
+        )
+
     def release_memory_occupation(self, tags: Optional[List[str]] = None):
         obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
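For context, a minimal usage sketch of the new engine-level API (not part of this commit): the model path, adapter names, and adapter paths below are placeholders, and the launch arguments are assumptions about a typical LoRA-enabled setup rather than a prescribed configuration.

    # Hypothetical usage sketch; all paths and launch arguments are placeholders.
    import sglang as sgl

    engine = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",     # placeholder base model
        lora_paths={"adapter_a": "/path/to/adapter_a"},    # assumed initial adapter
        max_loras_per_batch=2,                             # assumed LoRA batching limit
    )

    # Load a second adapter at runtime, without restarting the engine.
    result = engine.load_lora_adapter(lora_name="adapter_b", lora_path="/path/to/adapter_b")
    assert result.success, result.error_message

    # ...serve requests that reference "adapter_b"...

    # Unload it again once it is no longer needed.
    engine.unload_lora_adapter(lora_name="adapter_b")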
python/sglang/srt/entrypoints/http_server.py

@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
     OpenSessionReqInput,
     ParseFunctionCallReq,
     ProfileReqInput,

@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
     SeparateReasoningReqInput,
     SetInternalStateReq,
     SlowDownReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -595,6 +597,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
         return _create_error_response(e)

+@app.api_route("/load_lora_adapter", methods=["POST"])
+async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
+    """Load a new LoRA adapter without re-launching the server."""
+    result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)
+
+    if result.success:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
+@app.api_route("/unload_lora_adapter", methods=["POST"])
+async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
+    """Unload a LoRA adapter without re-launching the server."""
+    result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
+
+    if result.success:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
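The same operations are reachable over HTTP. A hypothetical client-side sketch follows; the server address is a placeholder, and the JSON fields mirror LoadLoRAAdapterReqInput / UnloadLoRAAdapterReqInput above.

    import requests

    base_url = "http://127.0.0.1:30000"  # placeholder address of a running sglang server

    # POST /load_lora_adapter returns 200 with the LoRAUpdateResult payload on
    # success, or 400 with error_message on failure.
    resp = requests.post(
        f"{base_url}/load_lora_adapter",
        json={"lora_name": "adapter_b", "lora_path": "/path/to/adapter_b"},
    )
    print(resp.status_code, resp.json())

    # POST /unload_lora_adapter removes the adapter again.
    resp = requests.post(
        f"{base_url}/unload_lora_adapter",
        json={"lora_name": "adapter_b"},
    )
    print(resp.status_code, resp.json())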
python/sglang/srt/lora/lora.py

@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for i in range(base_hf_config.num_hidden_layers)
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )

@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
             else:
                 self.weights[name] = loaded_weight.cpu()

-        # stack kv_proj and gate_up_proj
-        for i in range(self.base_hf_config.num_hidden_layers):
-            layer = self.layers[i]
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
python/sglang/srt/lora/lora_manager.py

@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_lora_weight_names,
     get_weight_name,
 )
+from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import replace_submodule

@@ -98,44 +99,96 @@ class LoRAManager:
         ],
     )

-    def load_lora_adapters(self, lora_paths: Dict[str, str]):
-        """
-        Load LoRA adapters from the specified paths.
-        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
-
-        Args:
-            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
-                If a LoRA adapter is already loaded, it will be skipped with a warning.
-        """
-        for lora_name, lora_path in lora_paths.items():
-            if lora_name in self.loras:
-                logger.warning(
-                    f"LoRA adapter {lora_name} is already loaded."
-                    "If you want to reload it, please unload it first."
-                )
-                continue
-            self.configs[lora_name] = LoRAConfig(lora_path)
-
-        self.update_state_from_configs()
-
-    def unload_lora_adapters(self, lora_names: Set[str]):
-        """
-        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
-        delete the corresponding LoRA modules.
-
-        Args:
-            lora_names (Set[str]): A set of LoRA adapter names to unload.
-        """
-        for lora_name in lora_names:
-            if lora_name in self.loras:
-                del self.configs[lora_name]
-            else:
-                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
-
-        self.update_state_from_configs()
+    def create_lora_update_result(
+        self, success: bool, error_message: str = ""
+    ) -> LoRAUpdateResult:
+        return LoRAUpdateResult(
+            success=success,
+            error_message=error_message,
+            loaded_adapters={
+                name: config.path for name, config in self.configs.items()
+            },
+        )
+
+    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
+        """
+        Load LoRA adapters from the specified paths.
+
+        Args:
+            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
+                If a LoRA adapter is already loaded, it will be skipped with a warning.
+        """
+        results = []
+        for lora_name, lora_path in lora_paths.items():
+            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+            results.append(result)
+
+        self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=all(result.success for result in results),
+            error_message="\n".join(
+                result.error_message for result in results if not result.success
+            ),
+        )
+
+    def load_lora_adapter(
+        self, lora_name: str, lora_path: str, update_state: bool = True
+    ) -> LoRAUpdateResult:
+        """
+        Load a single LoRA adapter from the specified path.
+
+        Args:
+            lora_name (str): The name of the LoRA adapter.
+            lora_path (str): The file path to the LoRA adapter.
+            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+        """
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            success = False
+            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+        try:
+            self.configs[lora_name] = LoRAConfig(lora_path)
+        except Exception as e:
+            success = False
+            error_message = (
+                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            )
+
+        if update_state:
+            self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
+    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
+        """
+        Unload a LoRA adapter by its name. This removes the adapter from the memory pool and
+        deletes the corresponding LoRA modules.
+        """
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            del self.configs[lora_name]
+        else:
+            error_message = f"LoRA adapter {lora_name} is not loaded."
+            success = False
+
+        self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)

@@ -372,8 +425,8 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter

-        # Clean up unused LoRA adapters
-        for name in self.loras:
+        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+        for name in list(self.loras):
             if name not in self.configs:
                 logger.info(f"Unloading LoRA adapter {name}")
                 del self.loras[name]
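The change in the cleanup loop above is subtle: deleting entries from a dict while iterating over it directly raises a RuntimeError in Python, so the keys are copied first. A standalone illustration of the pattern (the dicts here are stand-ins for self.loras and self.configs):

    # Stand-ins for self.loras and self.configs.
    loras = {"adapter_a": object(), "adapter_b": object()}
    configs = {"adapter_a": object()}

    # Iterating over a copy of the keys makes it safe to delete from `loras`
    # inside the loop; iterating over `loras` directly while deleting would raise
    # "RuntimeError: dictionary changed size during iteration".
    for name in list(loras):
        if name not in configs:
            del loras[name]

    print(loras.keys())  # dict_keys(['adapter_a'])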
python/sglang/srt/managers/io_struct.py

@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

 from sglang.srt.multimodal.mm_utils import has_valid_data

@@ -1002,3 +1002,27 @@ class RpcReqInput:
 class RpcReqOutput:
     success: bool
     message: str
+
+
+@dataclass
+class LoadLoRAAdapterReqInput:
+    # The name of the LoRA adapter to be newly loaded.
+    lora_name: str
+    # The path to load the adapter from.
+    lora_path: str
+
+
+@dataclass
+class UnloadLoRAAdapterReqInput:
+    # The name of the LoRA adapter to unload.
+    lora_name: str
+
+
+@dataclass
+class LoRAUpdateResult:
+    success: bool
+    error_message: Optional[str] = None
+    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+
+
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
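A small illustrative sketch of how these dataclasses are used (the adapter name and path are placeholders): the request dataclasses double as the JSON schema of the new HTTP endpoints, and both request-output names resolve to the same LoRAUpdateResult class.

    from dataclasses import asdict

    from sglang.srt.managers.io_struct import (
        LoadLoRAAdapterReqInput,
        LoadLoRAAdapterReqOutput,
        LoRAUpdateResult,
        UnloadLoRAAdapterReqOutput,
    )

    # The dataclass fields are exactly the JSON body of POST /load_lora_adapter.
    req = LoadLoRAAdapterReqInput(lora_name="adapter_b", lora_path="/path/to/adapter_b")
    print(asdict(req))  # {'lora_name': 'adapter_b', 'lora_path': '/path/to/adapter_b'}

    # Both output types are aliases of LoRAUpdateResult.
    assert LoadLoRAAdapterReqOutput is UnloadLoRAAdapterReqOutput is LoRAUpdateResult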
python/sglang/srt/managers/scheduler.py

@@ -82,6 +82,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,

@@ -99,6 +101,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,

@@ -519,6 +523,8 @@ class Scheduler(
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
                 (ExpertDistributionReq, self.expert_distribution_handle),
+                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
+                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
             ]
         )

@@ -2241,6 +2247,36 @@ class Scheduler(
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)

+    def load_lora_adapter(
+        self, recv_req: LoadLoRAAdapterReqInput
+    ) -> LoadLoRAAdapterReqOutput:
+        """In-place loading a new lora adapter from disk or huggingface."""
+
+        result = self.tp_worker.load_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after loading lora adapter."
+        else:
+            logger.error(result.error_message)
+        return result
+
+    def unload_lora_adapter(
+        self, recv_req: UnloadLoRAAdapterReqInput
+    ) -> UnloadLoRAAdapterReqOutput:
+        """Unload the lora adapter."""
+
+        result = self.tp_worker.unload_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert (
+                flush_cache_success
+            ), "Cache flush failed after unloading LoRA weights"
+        else:
+            logger.error(result.error_message)
+        return result
+
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
python/sglang/srt/managers/tokenizer_manager.py

@@ -83,6 +83,9 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
+    LoRAUpdateResult,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,

@@ -99,6 +102,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,

@@ -311,6 +316,9 @@ class TokenizerManager:
         self.expert_distribution_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.update_lora_adapter_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

         self._result_dispatcher = TypeBasedDispatcher(
             [

@@ -377,6 +385,10 @@ class TokenizerManager:
                 (
                     ExpertDistributionReqOutput,
                     self.expert_distribution_communicator.handle_recv,
                 ),
+                (
+                    LoRAUpdateResult,
+                    self.update_lora_adapter_communicator.handle_recv,
+                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )

@@ -960,6 +972,49 @@ class TokenizerManager:
         result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message

+    async def load_lora_adapter(
+        self,
+        obj: LoadLoRAAdapterReqInput,
+        _: Optional[fastapi.Request] = None,
+    ) -> LoadLoRAAdapterReqOutput:
+        self.auto_create_handle_loop()
+
+        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+        # with dp_size > 1.
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for dynamic lora loading"
+        logger.info(
+            "Start load Lora adapter. Lora name=%s, path=%s",
+            obj.lora_name,
+            obj.lora_path,
+        )
+
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_lora_adapter_communicator(obj))[0]
+            return result
+
+    async def unload_lora_adapter(
+        self,
+        obj: UnloadLoRAAdapterReqInput,
+        _: Optional[fastapi.Request] = None,
+    ) -> UnloadLoRAAdapterReqOutput:
+        self.auto_create_handle_loop()
+
+        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+        # with dp_size > 1.
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for dynamic lora loading"
+        logger.info(
+            "Start unload Lora adapter. Lora name=%s",
+            obj.lora_name,
+        )
+
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_lora_adapter_communicator(obj))[0]
+            return result
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
python/sglang/srt/managers/tp_worker.py

@@ -30,6 +30,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -275,3 +277,13 @@ class TpModelWorker:
             recv_req.name, recv_req.truncate_size
         )
         return parameter
+
+    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+        result = self.model_runner.load_lora_adapter(
+            recv_req.lora_name, recv_req.lora_path
+        )
+        return result
+
+    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+        result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
+        return result
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -26,6 +26,8 @@ import torch
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -268,6 +270,12 @@ class TpModelWorkerClient:
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return self.worker.get_weights_by_name(recv_req)

+    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+        return self.worker.load_lora_adapter(recv_req)
+
+    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+        return self.worker.unload_lora_adapter(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
python/sglang/srt/model_executor/model_runner.py

@@ -26,7 +26,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist

-from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig

@@ -819,8 +818,47 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
-        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
-        logger.info("LoRA manager ready.")
+        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
+        if result.success:
+            logger.info(
+                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
+            )
+        else:
+            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
+
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new lora adapter from disk or huggingface."""
+
+        logger.info(
+            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
+        logger.info(
+            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+        return result
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""
+
+        logger.info(
+            f"LoRA adapter unloading starts: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+        result = self.lora_manager.unload_lora_adapter(lora_name)
+        logger.info(
+            f"LoRA adapter unloading completes: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+        return result

     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
python/sglang/test/runners.py

@@ -503,6 +503,7 @@ class SRTRunner:
         disable_overlap_schedule: bool = False,
         disable_custom_all_reduce: bool = False,
         torchao_config: Optional[str] = None,
+        cuda_graph_max_bs: int = 4,
         sleep_on_idle=False,
     ):
         self.model_type = model_type

@@ -539,7 +540,7 @@ class SRTRunner:
                 tokenizer_path=tokenizer_path,
                 enable_ep_moe=enable_ep_moe,
                 disable_overlap_schedule=disable_overlap_schedule,
-                cuda_graph_max_bs=4,
+                cuda_graph_max_bs=cuda_graph_max_bs,
                 disable_custom_all_reduce=disable_custom_all_reduce,
                 sleep_on_idle=sleep_on_idle,
                 **spec_kwargs,

@@ -552,6 +553,12 @@ class SRTRunner:
         else:
             self.tokenizer = None

+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        return self.engine.load_lora_adapter(lora_name, lora_path)
+
+    def unload_lora_adapter(self, lora_name: str):
+        return self.engine.unload_lora_adapter(lora_name)
+
     def forward(
         self,
         prompts: Union[
test/srt/models/lora/test_lora_update.py (new file, +616 lines)

This diff is collapsed in the web view.
test/srt/run_suite.py

@@ -17,6 +17,7 @@ suites = {
         TestFile("models/lora/test_lora_backend.py", 99),
         TestFile("models/lora/test_multi_lora_backend.py", 60),
         TestFile("models/lora/test_lora_cuda_graph.py", 250),
+        TestFile("models/lora/test_lora_update.py", 400),
         TestFile("models/test_embedding_models.py", 73),
         # TestFile("models/test_clip_models.py", 52),
         TestFile("models/test_encoder_embedding_models.py", 100),
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment