Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
1a08358a
"docs/vscode:/vscode.git/clone" did not exist on "fcf5ad5f494eed6542d7a236346dcd7b359fad09"
Unverified
Commit
1a08358a
authored
Jul 01, 2025
by
Lifu Huang
Committed by
GitHub
Jul 01, 2025
Browse files
Improve error handling for requests with unloaded LoRA path(s) (#7642)
parent
f18a8fdd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
135 additions
and
20 deletions
+135
-20
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+25
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-2
test/srt/models/lora/test_lora_update.py
test/srt/models/lora/test_lora_update.py
+108
-18
No files found.
python/sglang/srt/managers/tokenizer_manager.py
View file @
1a08358a
...
@@ -240,6 +240,12 @@ class TokenizerManager:
...
@@ -240,6 +240,12 @@ class TokenizerManager:
revision
=
server_args
.
revision
,
revision
=
server_args
.
revision
,
)
)
# Initialize loaded loRA adapters with the initial lora paths in the server_args.
# This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
self
.
loaded_lora_adapters
:
Dict
[
str
,
str
]
=
dict
(
self
.
server_args
.
lora_paths
or
{}
)
# Store states
# Store states
self
.
no_create_loop
=
False
self
.
no_create_loop
=
False
self
.
rid_to_state
:
Dict
[
str
,
ReqState
]
=
{}
self
.
rid_to_state
:
Dict
[
str
,
ReqState
]
=
{}
...
@@ -549,6 +555,8 @@ class TokenizerManager:
...
@@ -549,6 +555,8 @@ class TokenizerManager:
"The server is not configured to enable custom logit processor. "
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
"Please set `--enable-custom-logits-processor` to enable this feature."
)
)
if
self
.
server_args
.
lora_paths
and
obj
.
lora_path
:
self
.
_validate_lora_adapters
(
obj
)
def
_validate_input_ids_in_vocab
(
def
_validate_input_ids_in_vocab
(
self
,
input_ids
:
List
[
int
],
vocab_size
:
int
self
,
input_ids
:
List
[
int
],
vocab_size
:
int
...
@@ -662,6 +670,21 @@ class TokenizerManager:
...
@@ -662,6 +670,21 @@ class TokenizerManager:
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
)
)
def
_validate_lora_adapters
(
self
,
obj
:
GenerateReqInput
):
"""Validate that the requested LoRA adapters are loaded."""
requested_adapters
=
(
set
(
obj
.
lora_path
)
if
isinstance
(
obj
.
lora_path
,
list
)
else
{
obj
.
lora_path
}
)
loaded_adapters
=
(
self
.
loaded_lora_adapters
.
keys
()
if
self
.
loaded_lora_adapters
else
set
()
)
unloaded_adapters
=
requested_adapters
-
loaded_adapters
if
unloaded_adapters
:
raise
ValueError
(
f
"The following requested LoRA adapters are not loaded:
{
unloaded_adapters
}
\n
"
f
"Loaded adapters:
{
loaded_adapters
}
."
)
def
_send_one_request
(
def
_send_one_request
(
self
,
self
,
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
...
@@ -988,6 +1011,7 @@ class TokenizerManager:
...
@@ -988,6 +1011,7 @@ class TokenizerManager:
async
with
self
.
model_update_lock
.
writer_lock
:
async
with
self
.
model_update_lock
.
writer_lock
:
result
=
(
await
self
.
update_lora_adapter_communicator
(
obj
))[
0
]
result
=
(
await
self
.
update_lora_adapter_communicator
(
obj
))[
0
]
self
.
loaded_lora_adapters
=
result
.
loaded_adapters
return
result
return
result
async
def
unload_lora_adapter
(
async
def
unload_lora_adapter
(
...
@@ -1009,6 +1033,7 @@ class TokenizerManager:
...
@@ -1009,6 +1033,7 @@ class TokenizerManager:
async
with
self
.
model_update_lock
.
writer_lock
:
async
with
self
.
model_update_lock
.
writer_lock
:
result
=
(
await
self
.
update_lora_adapter_communicator
(
obj
))[
0
]
result
=
(
await
self
.
update_lora_adapter_communicator
(
obj
))[
0
]
self
.
loaded_lora_adapters
=
result
.
loaded_adapters
return
result
return
result
async
def
get_weights_by_name
(
async
def
get_weights_by_name
(
...
...
python/sglang/srt/server_args.py
View file @
1a08358a
...
@@ -20,7 +20,7 @@ import logging
...
@@ -20,7 +20,7 @@ import logging
import
os
import
os
import
random
import
random
import
tempfile
import
tempfile
from
typing
import
List
,
Literal
,
Optional
from
typing
import
List
,
Literal
,
Optional
,
Union
from
sglang.srt.hf_transformers_utils
import
check_gguf_file
,
get_config
from
sglang.srt.hf_transformers_utils
import
check_gguf_file
,
get_config
from
sglang.srt.reasoning_parser
import
ReasoningParser
from
sglang.srt.reasoning_parser
import
ReasoningParser
...
@@ -131,7 +131,7 @@ class ServerArgs:
...
@@ -131,7 +131,7 @@ class ServerArgs:
preferred_sampling_params
:
Optional
[
str
]
=
None
preferred_sampling_params
:
Optional
[
str
]
=
None
# LoRA
# LoRA
lora_paths
:
Optional
[
List
[
str
]]
=
None
lora_paths
:
Optional
[
Union
[
dict
[
str
,
str
],
List
[
str
]]
]
=
None
max_loras_per_batch
:
int
=
8
max_loras_per_batch
:
int
=
8
lora_backend
:
str
=
"triton"
lora_backend
:
str
=
"triton"
...
...
test/srt/models/lora/test_lora_update.py
View file @
1a08358a
...
@@ -16,7 +16,7 @@ import multiprocessing as mp
...
@@ -16,7 +16,7 @@ import multiprocessing as mp
import
unittest
import
unittest
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
List
,
Optional
,
Union
from
typing
import
Any
,
List
,
Optional
,
Union
import
requests
import
requests
import
torch
import
torch
...
@@ -42,14 +42,16 @@ PROMPTS = [
...
@@ -42,14 +42,16 @@ PROMPTS = [
class
OperationType
(
Enum
):
class
OperationType
(
Enum
):
LOAD
=
"load"
LOAD
=
"load"
UNLOAD
=
"unload"
UNLOAD
=
"unload"
NOOP
=
"noop"
FORWARD
=
"forward"
FORWARD
=
"forward"
EXPECT_ERROR
=
"expect_error"
@
dataclass
@
dataclass
class
Operation
:
class
Operation
:
# Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
type
:
OperationType
type
:
OperationType
data
:
Optional
[
str
]
# Data associated with the operation. Exact type varies depending on the operation
data
:
Optional
[
Any
]
@
dataclass
@
dataclass
...
@@ -62,7 +64,7 @@ class TestCase:
...
@@ -62,7 +64,7 @@ class TestCase:
max_new_tokens
:
int
=
32
max_new_tokens
:
int
=
32
def
create_batch_data
(
adapters
:
Union
[
str
,
list
])
->
dict
:
def
create_batch_data
(
adapters
:
Union
[
str
,
list
])
->
List
[
tuple
[
str
,
str
]]
:
if
not
isinstance
(
adapters
,
list
):
if
not
isinstance
(
adapters
,
list
):
adapters
=
[
adapters
]
adapters
=
[
adapters
]
return
[(
prompt
,
adapter
)
for
prompt
in
PROMPTS
for
adapter
in
adapters
]
return
[(
prompt
,
adapter
)
for
prompt
in
PROMPTS
for
adapter
in
adapters
]
...
@@ -80,6 +82,26 @@ TEST_CASES = [
...
@@ -80,6 +82,26 @@ TEST_CASES = [
],
],
initial_adapters
=
[
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
],
initial_adapters
=
[
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
],
op_sequence
=
[
op_sequence
=
[
Operation
(
type
=
OperationType
.
FORWARD
,
data
=
create_batch_data
(
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
),
),
Operation
(
type
=
OperationType
.
EXPECT_ERROR
,
data
=
(
create_batch_data
(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded"
,
),
),
Operation
(
type
=
OperationType
.
EXPECT_ERROR
,
data
=
(
create_batch_data
(
"pbevan11/llama-3.1-8b-ocr-correction"
),
"not loaded"
,
),
),
Operation
(
Operation
(
type
=
OperationType
.
LOAD
,
type
=
OperationType
.
LOAD
,
data
=
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
,
data
=
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
,
...
@@ -102,6 +124,13 @@ TEST_CASES = [
...
@@ -102,6 +124,13 @@ TEST_CASES = [
type
=
OperationType
.
UNLOAD
,
type
=
OperationType
.
UNLOAD
,
data
=
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
,
data
=
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
,
),
),
Operation
(
type
=
OperationType
.
EXPECT_ERROR
,
data
=
(
create_batch_data
(
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
),
"not loaded"
,
),
),
Operation
(
Operation
(
type
=
OperationType
.
FORWARD
,
type
=
OperationType
.
FORWARD
,
data
=
create_batch_data
(
data
=
create_batch_data
(
...
@@ -115,6 +144,15 @@ TEST_CASES = [
...
@@ -115,6 +144,15 @@ TEST_CASES = [
type
=
OperationType
.
UNLOAD
,
type
=
OperationType
.
UNLOAD
,
data
=
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
,
data
=
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
,
),
),
Operation
(
type
=
OperationType
.
EXPECT_ERROR
,
data
=
(
create_batch_data
(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded"
,
),
),
Operation
(
Operation
(
type
=
OperationType
.
FORWARD
,
type
=
OperationType
.
FORWARD
,
data
=
create_batch_data
(
"pbevan11/llama-3.1-8b-ocr-correction"
),
data
=
create_batch_data
(
"pbevan11/llama-3.1-8b-ocr-correction"
),
...
@@ -149,6 +187,22 @@ TEST_CASES = [
...
@@ -149,6 +187,22 @@ TEST_CASES = [
type
=
OperationType
.
FORWARD
,
type
=
OperationType
.
FORWARD
,
data
=
create_batch_data
(
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
),
data
=
create_batch_data
(
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
),
),
),
Operation
(
type
=
OperationType
.
EXPECT_ERROR
,
data
=
(
create_batch_data
(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded"
,
),
),
Operation
(
type
=
OperationType
.
EXPECT_ERROR
,
data
=
(
create_batch_data
(
"pbevan11/llama-3.1-8b-ocr-correction"
),
"not loaded"
,
),
),
Operation
(
Operation
(
type
=
OperationType
.
LOAD
,
type
=
OperationType
.
LOAD
,
data
=
"pbevan11/llama-3.1-8b-ocr-correction"
,
data
=
"pbevan11/llama-3.1-8b-ocr-correction"
,
...
@@ -157,6 +211,13 @@ TEST_CASES = [
...
@@ -157,6 +211,13 @@ TEST_CASES = [
type
=
OperationType
.
UNLOAD
,
type
=
OperationType
.
UNLOAD
,
data
=
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
,
data
=
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
,
),
),
Operation
(
type
=
OperationType
.
EXPECT_ERROR
,
data
=
(
create_batch_data
(
"philschmid/code-llama-3-1-8b-text-to-sql-lora"
),
"not loaded"
,
),
),
Operation
(
Operation
(
type
=
OperationType
.
FORWARD
,
type
=
OperationType
.
FORWARD
,
data
=
create_batch_data
(
"pbevan11/llama-3.1-8b-ocr-correction"
),
data
=
create_batch_data
(
"pbevan11/llama-3.1-8b-ocr-correction"
),
...
@@ -332,19 +393,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
...
@@ -332,19 +393,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
prompts
:
List
[
str
],
prompts
:
List
[
str
],
lora_paths
:
List
[
str
],
lora_paths
:
List
[
str
],
max_new_tokens
:
int
=
32
,
max_new_tokens
:
int
=
32
,
expected_error
:
str
=
None
,
):
):
"""
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
Perform a batch forward pass with the current set of loaded LoRA adapters.
"""
"""
response
=
self
.
handle
.
batch_forward
(
try
:
prompts
=
prompts
,
response
=
self
.
handle
.
batch_forward
(
lora_paths
=
lora_paths
,
prompts
=
prompts
,
max_new_tokens
=
max_new_tokens
,
lora_paths
=
lora_paths
,
)
max_new_tokens
=
max_new_tokens
,
output_strs
=
response
.
output_strs
)
except
ValueError
as
e
:
if
expected_error
:
error_message
=
str
(
e
)
self
.
testcase
.
assertIn
(
expected_error
,
error_message
)
print
(
f
"Received error as expected:
{
error_message
}
"
)
return
error_message
raise
e
self
.
testcase
.
assertEqual
(
len
(
response
.
output_strs
),
len
(
prompts
))
output
=
response
.
output_strs
print
(
f
"output_strs:
{
output
}
"
)
print
(
f
"output_strs:
{
output_strs
}
"
)
return
output
return
output_strs
class
LoRAUpdateServerTestSession
(
LoRAUpdateTestSessionBase
):
class
LoRAUpdateServerTestSession
(
LoRAUpdateTestSessionBase
):
...
@@ -426,6 +499,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
...
@@ -426,6 +499,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
prompts
:
List
[
str
],
prompts
:
List
[
str
],
lora_paths
:
List
[
str
],
lora_paths
:
List
[
str
],
max_new_tokens
:
int
=
32
,
max_new_tokens
:
int
=
32
,
expected_error
:
str
=
None
,
):
):
"""
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
Perform a batch forward pass with the current set of loaded LoRA adapters.
...
@@ -442,11 +516,18 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
...
@@ -442,11 +516,18 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
},
},
},
},
)
)
self
.
testcase
.
assertTrue
(
response
.
ok
)
if
expected_error
:
output_strs
=
[
r
[
"text"
]
for
r
in
response
.
json
()]
self
.
testcase
.
assertEqual
(
response
.
status_code
,
400
)
self
.
testcase
.
assertIn
(
expected_error
,
response
.
text
)
print
(
f
"output_strs:
{
output_strs
}
"
)
output
=
response
.
text
return
output_strs
print
(
f
"Received error as expected:
{
response
.
text
}
"
)
return
output
else
:
self
.
testcase
.
assertTrue
(
response
.
ok
)
output
=
[
r
[
"text"
]
for
r
in
response
.
json
()]
self
.
testcase
.
assertEqual
(
len
(
output
),
len
(
prompts
))
print
(
f
"output_strs:
{
output
}
"
)
return
output
# Factory function to create the appropriate LoRA test session based on mode
# Factory function to create the appropriate LoRA test session based on mode
...
@@ -535,14 +616,23 @@ class TestLoRADynamicUpdate(CustomTestCase):
...
@@ -535,14 +616,23 @@ class TestLoRADynamicUpdate(CustomTestCase):
max_new_tokens
=
max_new_tokens
,
max_new_tokens
=
max_new_tokens
,
)
)
forward_outputs
.
append
(
result
)
forward_outputs
.
append
(
result
)
elif
op_type
==
OperationType
.
EXPECT_ERROR
:
input_data
,
expected_error
=
data
prompts
,
adapters
=
zip
(
*
input_data
)
result
=
session
.
forward
(
prompts
=
list
(
prompts
),
lora_paths
=
list
(
adapters
),
max_new_tokens
=
max_new_tokens
,
expected_error
=
expected_error
,
)
return
forward_outputs
return
forward_outputs
def
test_dynamic_adapter_updates
(
self
):
def
test_dynamic_adapter_updates
(
self
):
for
case_idx
,
test_case
in
enumerate
(
TEST_CASES
,
start
=
1
):
for
case_idx
,
test_case
in
enumerate
(
TEST_CASES
,
start
=
1
):
for
mode
in
[
for
mode
in
[
LoRAUpdateTestSessionMode
.
SERVER
,
LoRAUpdateTestSessionMode
.
ENGINE
,
LoRAUpdateTestSessionMode
.
ENGINE
,
LoRAUpdateTestSessionMode
.
SERVER
,
]:
]:
print
(
"="
*
100
)
print
(
"="
*
100
)
print
(
f
"Starting test case
{
case_idx
}
in
{
mode
.
value
}
mode."
)
print
(
f
"Starting test case
{
case_idx
}
in
{
mode
.
value
}
mode."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment