Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0e391e75
Unverified
Commit
0e391e75
authored
Dec 16, 2025
by
Jee Jee Li
Committed by
GitHub
Dec 16, 2025
Browse files
[Bugfix] Fix RequestOutput miss lora_request (#30636)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
0d0c929f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
7 deletions
+19
-7
tests/lora/test_gptoss_tp.py
tests/lora/test_gptoss_tp.py
+5
-1
tests/lora/test_llama_tp.py
tests/lora/test_llama_tp.py
+8
-1
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+6
-5
No files found.
tests/lora/test_gptoss_tp.py
View file @
0e391e75
...
@@ -76,6 +76,8 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
...
@@ -76,6 +76,8 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
enable_lora
=
True
,
enable_lora
=
True
,
max_loras
=
4
,
max_loras
=
4
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
max_num_seqs
=
2
,
max_num_batched_tokens
=
2048
,
compilation_config
=
vllm
.
config
.
CompilationConfig
(
# Avoid OOM
compilation_config
=
vllm
.
config
.
CompilationConfig
(
# Avoid OOM
cudagraph_specialize_lora
=
False
,
cudagraph_specialize_lora
=
False
,
),
),
...
@@ -94,8 +96,10 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
...
@@ -94,8 +96,10 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
enable_lora
=
True
,
enable_lora
=
True
,
max_loras
=
2
,
max_loras
=
2
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
max_num_seqs
=
16
,
max_num_seqs
=
2
,
max_num_batched_tokens
=
2048
,
tensor_parallel_size
=
2
,
tensor_parallel_size
=
2
,
gpu_memory_utilization
=
0.8
,
fully_sharded_loras
=
fully_sharded_loras
,
fully_sharded_loras
=
fully_sharded_loras
,
compilation_config
=
vllm
.
config
.
CompilationConfig
(
# Avoid OOM
compilation_config
=
vllm
.
config
.
CompilationConfig
(
# Avoid OOM
cudagraph_specialize_lora
=
False
,
cudagraph_specialize_lora
=
False
,
...
...
tests/lora/test_llama_tp.py
View file @
0e391e75
...
@@ -76,11 +76,18 @@ def do_sample(
...
@@ -76,11 +76,18 @@ def do_sample(
if
lora_id
if
lora_id
else
None
,
else
None
,
)
)
# Print the outputs.
lora_request
=
LoRARequest
(
str
(
lora_id
),
lora_id
,
lora_path
)
if
lora_id
else
None
generated_texts
:
list
[
str
]
=
[]
generated_texts
:
list
[
str
]
=
[]
for
output
in
outputs
:
for
output
in
outputs
:
prompt
=
output
.
prompt
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
generated_text
=
output
.
outputs
[
0
].
text
# The output should include correct lora_request info
if
lora_request
is
not
None
:
assert
output
.
lora_request
.
lora_name
==
lora_request
.
lora_name
assert
output
.
lora_request
.
lora_int_id
==
lora_request
.
lora_int_id
assert
output
.
lora_request
.
lora_path
==
lora_request
.
lora_path
else
:
assert
output
.
lora_request
is
None
generated_texts
.
append
(
generated_text
)
generated_texts
.
append
(
generated_text
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
return
generated_texts
return
generated_texts
...
...
vllm/v1/engine/output_processor.py
View file @
0e391e75
...
@@ -8,6 +8,7 @@ from typing import Any, cast
...
@@ -8,6 +8,7 @@ from typing import Any, cast
import
torch
import
torch
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
(
from
vllm.outputs
import
(
CompletionOutput
,
CompletionOutput
,
PoolingOutput
,
PoolingOutput
,
...
@@ -93,7 +94,7 @@ class RequestState:
...
@@ -93,7 +94,7 @@ class RequestState:
request_id
:
str
,
request_id
:
str
,
parent_req
:
ParentRequest
|
None
,
parent_req
:
ParentRequest
|
None
,
request_index
:
int
,
request_index
:
int
,
lora_
name
:
st
r
|
None
,
lora_
request
:
LoRAReque
st
|
None
,
output_kind
:
RequestOutputKind
,
output_kind
:
RequestOutputKind
,
prompt
:
str
|
None
,
prompt
:
str
|
None
,
prompt_token_ids
:
list
[
int
]
|
None
,
prompt_token_ids
:
list
[
int
]
|
None
,
...
@@ -112,7 +113,8 @@ class RequestState:
...
@@ -112,7 +113,8 @@ class RequestState:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
parent_req
=
parent_req
self
.
parent_req
=
parent_req
self
.
request_index
=
request_index
self
.
request_index
=
request_index
self
.
lora_name
=
lora_name
self
.
lora_request
=
lora_request
self
.
lora_name
=
lora_request
.
lora_name
if
lora_request
is
not
None
else
None
self
.
output_kind
=
output_kind
self
.
output_kind
=
output_kind
self
.
prompt
=
prompt
self
.
prompt
=
prompt
self
.
prompt_token_ids
=
prompt_token_ids
self
.
prompt_token_ids
=
prompt_token_ids
...
@@ -178,9 +180,7 @@ class RequestState:
...
@@ -178,9 +180,7 @@ class RequestState:
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
parent_req
=
parent_req
,
parent_req
=
parent_req
,
request_index
=
request_index
,
request_index
=
request_index
,
lora_name
=
(
lora_request
=
request
.
lora_request
,
request
.
lora_request
.
name
if
request
.
lora_request
is
not
None
else
None
),
output_kind
=
output_kind
,
output_kind
=
output_kind
,
prompt
=
prompt
,
prompt
=
prompt
,
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt_token_ids
=
request
.
prompt_token_ids
,
...
@@ -289,6 +289,7 @@ class RequestState:
...
@@ -289,6 +289,7 @@ class RequestState:
return
RequestOutput
(
return
RequestOutput
(
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
self
.
lora_request
,
prompt
=
self
.
prompt
,
prompt
=
self
.
prompt
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
prompt_logprobs
=
prompt_logprobs
,
prompt_logprobs
=
prompt_logprobs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment