Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ac900c89
Unverified
Commit
ac900c89
authored
Feb 20, 2026
by
Cyrus Leung
Committed by
GitHub
Feb 19, 2026
Browse files
[Refactor] Implement output type check in LLM (#34794)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
76df6072
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
38 deletions
+58
-38
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+58
-34
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+0
-4
No files found.
vllm/entrypoints/llm.py
View file @
ac900c89
...
...
@@ -10,7 +10,7 @@ import cloudpickle
import
torch.nn
as
nn
from
pydantic
import
ValidationError
from
tqdm.auto
import
tqdm
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
,
overload
from
vllm.beam_search
import
(
BeamSearchInstance
,
...
...
@@ -94,6 +94,11 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
_O
=
TypeVar
(
"_O"
,
bound
=
RequestOutput
|
PoolingRequestOutput
,
default
=
RequestOutput
|
PoolingRequestOutput
,
)
_P
=
TypeVar
(
"_P"
,
bound
=
SamplingParams
|
PoolingParams
|
None
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
...
...
@@ -447,17 +452,16 @@ class LLM:
if
sampling_params
is
None
:
sampling_params
=
self
.
get_default_sampling_params
()
outputs
=
self
.
_run_completion
(
return
self
.
_run_completion
(
prompts
=
prompts
,
params
=
sampling_params
,
output_type
=
RequestOutput
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
priority
=
priority
,
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
enqueue
(
self
,
prompts
:
PromptType
|
Sequence
[
PromptType
],
...
...
@@ -524,23 +528,43 @@ class LLM:
return
request_ids
@
overload
def
wait_for_completion
(
self
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
)
->
list
[
RequestOutput
]:
)
->
list
[
RequestOutput
|
PoolingRequestOutput
]:
...
@
overload
def
wait_for_completion
(
self
,
output_type
:
type
[
_O
]
|
tuple
[
type
[
_O
],
...],
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
)
->
list
[
_O
]:
...
def
wait_for_completion
(
self
,
output_type
:
type
[
Any
]
|
tuple
[
type
[
Any
],
...]
|
None
=
None
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
)
->
list
[
Any
]:
"""Wait for all enqueued requests to complete and return results.
This method processes all requests currently in the engine queue
and returns their outputs. Use after enqueue() to get results.
Args:
output_type: The expected output type, defaults to RequestOutput.
use_tqdm: If True, shows a tqdm progress bar.
Returns:
A list of
RequestO
utput objects for all completed requests.
A list of
o
utput objects for all completed requests.
"""
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
if
output_type
is
None
:
output_type
=
(
RequestOutput
,
PoolingRequestOutput
)
return
self
.
_run_engine
(
output_type
,
use_tqdm
=
use_tqdm
)
def
_resolve_mm_lora
(
self
,
...
...
@@ -744,13 +768,13 @@ class LLM:
# only runs for one step
# we don't need to use tqdm here
raw_
output
=
self
.
_render_and_run_requests
(
output
=
self
.
_render_and_run_requests
(
prompts
=
(
beam
.
get_prompt
()
for
beam
in
all_beams
),
params
=
self
.
_params_to_seq
(
sampling_params
,
len
(
all_beams
)),
output_type
=
RequestOutput
,
lora_requests
=
[
beam
.
lora_request
for
beam
in
all_beams
],
use_tqdm
=
False
,
)
output
=
self
.
engine_class
.
validate_outputs
(
raw_output
,
RequestOutput
)
for
(
start
,
end
),
instance
in
zip
(
instance_start_and_end
,
instances_batch
...
...
@@ -987,9 +1011,10 @@ class LLM:
if
sampling_params
is
None
:
sampling_params
=
self
.
get_default_sampling_params
()
outputs
=
self
.
_run_chat
(
return
self
.
_run_chat
(
messages
=
messages
,
params
=
sampling_params
,
output_type
=
RequestOutput
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
chat_template
=
chat_template
,
...
...
@@ -1002,8 +1027,6 @@ class LLM:
mm_processor_kwargs
=
mm_processor_kwargs
,
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
encode
(
self
,
prompts
:
PromptType
|
Sequence
[
PromptType
]
|
DataPrompt
,
...
...
@@ -1135,19 +1158,16 @@ class LLM:
outputs
=
self
.
_run_completion
(
prompts
=
prompts_seq
,
params
=
params_seq
,
output_type
=
PoolingRequestOutput
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
)
model_outputs
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
if
use_io_processor
:
# get the post-processed model outputs
assert
self
.
io_processor
is
not
None
processed_outputs
=
self
.
io_processor
.
post_process
(
model_
outputs
)
processed_outputs
=
self
.
io_processor
.
post_process
(
outputs
)
return
[
PoolingRequestOutput
[
Any
](
...
...
@@ -1160,8 +1180,8 @@ class LLM:
finished
=
True
,
)
]
else
:
return
model_
outputs
return
outputs
def
embed
(
self
,
...
...
@@ -1353,8 +1373,7 @@ class LLM:
embed_2
=
encoded_output_2
,
)
items
=
self
.
engine_class
.
validate_outputs
(
scores
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
scores
]
def
_late_interaction_score
(
self
,
...
...
@@ -1393,7 +1412,7 @@ class LLM:
)
text_2
.
append
(
text
)
encoded_output
:
list
[
PoolingRequestOutput
]
=
self
.
encode
(
encoded_output
=
self
.
encode
(
text_1
+
text_2
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
...
...
@@ -1402,8 +1421,8 @@ class LLM:
tokenization_kwargs
=
tokenization_kwargs
,
)
encoded_output_1
:
list
[
PoolingRequestOutput
]
=
encoded_output
[
0
:
len
(
text_1
)]
encoded_output_2
:
list
[
PoolingRequestOutput
]
=
encoded_output
[
len
(
text_1
)
:]
encoded_output_1
=
encoded_output
[
0
:
len
(
text_1
)]
encoded_output_2
=
encoded_output
[
len
(
text_1
)
:]
if
len
(
encoded_output_1
)
==
1
:
encoded_output_1
=
encoded_output_1
*
len
(
encoded_output_2
)
...
...
@@ -1434,8 +1453,7 @@ class LLM:
)
)
items
=
self
.
engine_class
.
validate_outputs
(
scores
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
scores
]
def
_cross_encoding_score
(
self
,
...
...
@@ -1491,13 +1509,12 @@ class LLM:
outputs
=
self
.
_run_completion
(
prompts
=
prompts
,
params
=
pooling_params_list
,
output_type
=
PoolingRequestOutput
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
)
items
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
outputs
]
def
score
(
self
,
...
...
@@ -1759,6 +1776,7 @@ class LLM:
params
:
SamplingParams
|
PoolingParams
|
Sequence
[
SamplingParams
|
PoolingParams
],
output_type
:
type
[
_O
],
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
...
...
@@ -1790,6 +1808,7 @@ class LLM:
)
),
params
=
seq_params
,
output_type
=
output_type
,
use_tqdm
=
use_tqdm
,
lora_requests
=
seq_lora_requests
,
priorities
=
seq_priority
,
...
...
@@ -1802,6 +1821,7 @@ class LLM:
params
:
SamplingParams
|
PoolingParams
|
Sequence
[
SamplingParams
|
PoolingParams
],
output_type
:
type
[
_O
],
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
...
...
@@ -1848,6 +1868,7 @@ class LLM:
)
),
params
=
seq_params
,
output_type
=
output_type
,
lora_requests
=
seq_lora_requests
,
use_tqdm
=
use_tqdm
,
)
...
...
@@ -1856,6 +1877,7 @@ class LLM:
self
,
prompts
:
Iterable
[
ProcessorInputs
],
params
:
Sequence
[
SamplingParams
|
PoolingParams
],
output_type
:
type
[
_O
],
*
,
lora_requests
:
Sequence
[
LoRARequest
|
None
]
|
None
=
None
,
priorities
:
Sequence
[
int
]
|
None
=
None
,
...
...
@@ -1878,7 +1900,7 @@ class LLM:
priorities
=
priorities
,
)
return
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
_run_engine
(
output_type
,
use_tqdm
=
use_tqdm
)
def
_render_and_add_requests
(
self
,
...
...
@@ -1932,9 +1954,10 @@ class LLM:
def
_run_engine
(
self
,
output_type
:
type
[
_O
]
|
tuple
[
type
[
_O
],
...],
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
)
->
list
[
RequestOutput
|
PoolingRequestOutput
]:
)
->
list
[
_O
]:
# Initialize tqdm.
if
use_tqdm
:
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
...
...
@@ -1947,14 +1970,15 @@ class LLM:
)
# Run the engine.
outputs
:
list
[
RequestOutput
|
PoolingRequestOutput
]
=
[]
outputs
:
list
[
_O
]
=
[]
total_in_toks
=
0
total_out_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_engine
.
step
()
for
output
in
step_outputs
:
assert
isinstance
(
output
,
output_type
)
if
output
.
finished
:
outputs
.
append
(
output
)
outputs
.
append
(
output
)
# type: ignore[arg-type]
if
use_tqdm
:
if
isinstance
(
output
,
RequestOutput
):
# Calculate tokens only for RequestOutput
...
...
vllm/v1/engine/llm_engine.py
View file @
ac900c89
...
...
@@ -199,10 +199,6 @@ class LLMEngine:
self
.
should_execute_dummy_batch
=
True
return
aggregated_has_unfinished
@
classmethod
def
validate_outputs
(
cls
,
outputs
,
output_type
):
return
outputs
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
if
not
hasattr
(
self
,
"_supported_tasks"
):
# Cache the result
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment