sglang, commit 9fdc6d6a (unverified)
Authored Mar 28, 2025 by Qiaolin Yu; committed by GitHub Mar 27, 2025
Parent: 42a45df0

Fix the lora adapter when lora path is none (#4799)

Co-authored-by: Beichen Ma <mabeichen12@gmail.com>
Showing 3 changed files with 17 additions and 14 deletions (+17 / -14):

  python/sglang/srt/lora/lora_manager.py   +0 / -4
  python/sglang/srt/lora/mem_pool.py       +1 / -1
  test/srt/models/lora/test_lora.py        +16 / -9
python/sglang/srt/lora/lora_manager.py

@@ -133,10 +133,6 @@ class LoRAManager:
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        # FIXME: Handle lora uid with None more safely
-        if cur_uids == set([None]):
-            return
-
         # set up batch info shared by all lora moruldes
         bs = forward_batch.batch_size
         seg_lens = (
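Note: the early return for an all-None batch is removed, so the batch info below it is now always built; requests whose lora_path is None are instead covered by the zeroed adapter slot that the memory pool prepares (see the mem_pool.py change below). The following standalone sketch, with illustrative names rather than sglang's real buffers or kernels, shows why a zeroed A matrix makes the LoRA delta an exact no-op:

    # Minimal sketch, not sglang code: a zeroed LoRA slot leaves the base output unchanged.
    import torch

    x = torch.randn(4, 16)           # token activations
    W = torch.randn(16, 16)          # base projection weight
    A = torch.zeros(16, 8)           # A-buffer slot reserved for uid=None (zeroed)
    B = torch.randn(8, 16)           # B-buffer slot; value is irrelevant once A is zero

    base_out = x @ W
    lora_out = x @ W + (x @ A) @ B   # the LoRA delta is exactly zero
    assert torch.equal(base_out, lora_out)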
python/sglang/srt/lora/mem_pool.py

@@ -163,7 +163,7 @@ class LoRAMemoryPool:
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
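Note: a plausible reading of this one-line change (the commit message does not spell it out): if the recycled buffer slot ever holds non-finite values, an in-place multiply keeps them (NaN * 0 = NaN, inf * 0 = NaN), while plain assignment overwrites the slot with real zeros. A standalone demonstration on a toy buffer:

    import torch

    buf = torch.tensor([[float("nan"), float("inf"), 1.0]])
    buffer_id = 0

    multiplied = buf.clone()
    multiplied[buffer_id] *= 0       # nan * 0 -> nan, inf * 0 -> nan
    assigned = buf.clone()
    assigned[buffer_id] = 0          # the whole slot becomes zeros

    print(multiplied)                # tensor([[nan, nan, 0.]])
    print(assigned)                  # tensor([[0., 0., 0.]])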
test/srt/models/lora/test_lora.py

@@ -96,6 +96,11 @@ class TestLoRA(CustomTestCase):
             srt_outputs = srt_runner.forward(
                 prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
             )
+            srt_outputs_lora_path_none = srt_runner.forward(
+                prompts,
+                max_new_tokens=max_new_tokens,
+                lora_paths=[None] * len(prompts),
+            )
 
         with HFRunner(
             base_path, torch_dtype=torch_dtype, model_type="generation"
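Note: the added forward call runs the same prompts with an adapter path of None for every position. The tiny sketch below (prompt strings and adapter names are hypothetical, not from this repository) illustrates the convention the test relies on: lora_paths is aligned index-by-index with prompts, and a None entry selects the plain base model for that prompt.

    prompts = ["What is LoRA?", "Summarize this text.", "Translate to French."]

    all_base = [None] * len(prompts)                      # what the new test passes
    mixed = ["org/finance-lora", None, "org/chat-lora"]   # hypothetical mixed batch

    assert len(all_base) == len(prompts) and len(mixed) == len(prompts)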
@@ -169,18 +174,20 @@ class TestLoRA(CustomTestCase):
         print(f"{srt_outputs.output_strs=}")
         print(f"{hf_no_lora_outputs.output_strs=}")
         print(f"{srt_no_lora_outputs.output_strs=}")
+        print(f"{srt_outputs_lora_path_none.output_strs=}")
         for i in range(len(prompts)):
             assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
                 srt_outputs.output_strs[i].strip(" "),
                 hf_outputs.output_strs[i],
             )
-            # assert (
-            #     srt_no_lora_outputs.output_strs[i].strip(" ")
-            #     == hf_no_lora_outputs.output_strs[i]
-            # ), (
-            #     srt_no_lora_outputs.output_strs[i].strip(" "),
-            #     hf_no_lora_outputs.output_strs[i],
-            # )
+            assert (
+                srt_no_lora_outputs.output_strs[i].strip(" ")
+                == hf_no_lora_outputs.output_strs[i]
+            ), (
+                srt_no_lora_outputs.output_strs[i].strip(" "),
+                hf_no_lora_outputs.output_strs[i],
+            )
+        assert srt_outputs_lora_path_none == srt_no_lora_outputs
 
     def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         print("=================== testing serving =======================")
@@ -257,7 +264,7 @@ class TestLoRA(CustomTestCase):
             srt_no_lora_logprobs = torch.Tensor(
                 srt_no_lora_outputs.top_input_logprobs[i]
             )
-            srt_logprobs = torch.uensor(srt_outputs.top_input_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
             print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))
 
             print(f"{srt_no_lora_outputs.output_strs=}")
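Note: the removed line called torch.uensor, which is not a torch attribute (apparently a typo), so this comparison would have raised AttributeError the first time the serving path actually ran; that path had been commented out until the next hunk re-enables it. A quick standalone check of the corrected call:

    import torch

    assert not hasattr(torch, "uensor")          # the misspelled name does not exist
    t = torch.Tensor([[0.1, 0.2], [0.3, 0.4]])   # torch.Tensor(...) builds a float32 tensor
    assert t.dtype == torch.float32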
@@ -280,7 +287,7 @@ class TestLoRA(CustomTestCase):
         tp_size = 1
         max_new_tokens = 32
         self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
-        # self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
+        self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
         # self.base_inference(
         #     PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
         # )