Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ac900c89
Unverified
Commit
ac900c89
authored
Feb 20, 2026
by
Cyrus Leung
Committed by
GitHub
Feb 19, 2026
Browse files
[Refactor] Implement output type check in LLM (#34794)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
76df6072
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
38 deletions
+58
-38
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+58
-34
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+0
-4
No files found.
vllm/entrypoints/llm.py
View file @
ac900c89
...
@@ -10,7 +10,7 @@ import cloudpickle
...
@@ -10,7 +10,7 @@ import cloudpickle
import
torch.nn
as
nn
import
torch.nn
as
nn
from
pydantic
import
ValidationError
from
pydantic
import
ValidationError
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
,
overload
from
vllm.beam_search
import
(
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchInstance
,
...
@@ -94,6 +94,11 @@ if TYPE_CHECKING:
...
@@ -94,6 +94,11 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_O
=
TypeVar
(
"_O"
,
bound
=
RequestOutput
|
PoolingRequestOutput
,
default
=
RequestOutput
|
PoolingRequestOutput
,
)
_P
=
TypeVar
(
"_P"
,
bound
=
SamplingParams
|
PoolingParams
|
None
)
_P
=
TypeVar
(
"_P"
,
bound
=
SamplingParams
|
PoolingParams
|
None
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
...
@@ -447,17 +452,16 @@ class LLM:
...
@@ -447,17 +452,16 @@ class LLM:
if
sampling_params
is
None
:
if
sampling_params
is
None
:
sampling_params
=
self
.
get_default_sampling_params
()
sampling_params
=
self
.
get_default_sampling_params
()
outputs
=
self
.
_run_completion
(
return
self
.
_run_completion
(
prompts
=
prompts
,
prompts
=
prompts
,
params
=
sampling_params
,
params
=
sampling_params
,
output_type
=
RequestOutput
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
priority
=
priority
,
priority
=
priority
,
)
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
enqueue
(
def
enqueue
(
self
,
self
,
prompts
:
PromptType
|
Sequence
[
PromptType
],
prompts
:
PromptType
|
Sequence
[
PromptType
],
...
@@ -524,23 +528,43 @@ class LLM:
...
@@ -524,23 +528,43 @@ class LLM:
return
request_ids
return
request_ids
@
overload
def
wait_for_completion
(
def
wait_for_completion
(
self
,
self
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
)
->
list
[
RequestOutput
]:
)
->
list
[
RequestOutput
|
PoolingRequestOutput
]:
...
@
overload
def
wait_for_completion
(
self
,
output_type
:
type
[
_O
]
|
tuple
[
type
[
_O
],
...],
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
)
->
list
[
_O
]:
...
def
wait_for_completion
(
self
,
output_type
:
type
[
Any
]
|
tuple
[
type
[
Any
],
...]
|
None
=
None
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
)
->
list
[
Any
]:
"""Wait for all enqueued requests to complete and return results.
"""Wait for all enqueued requests to complete and return results.
This method processes all requests currently in the engine queue
This method processes all requests currently in the engine queue
and returns their outputs. Use after enqueue() to get results.
and returns their outputs. Use after enqueue() to get results.
Args:
Args:
output_type: The expected output type(s). If omitted, both RequestOutput and PoolingRequestOutput are accepted.
use_tqdm: If True, shows a tqdm progress bar.
use_tqdm: If True, shows a tqdm progress bar.
Returns:
Returns:
A list of
RequestO
utput objects for all completed requests.
A list of
o
utput objects for all completed requests.
"""
"""
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
if
output_type
is
None
:
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
output_type
=
(
RequestOutput
,
PoolingRequestOutput
)
return
self
.
_run_engine
(
output_type
,
use_tqdm
=
use_tqdm
)
def
_resolve_mm_lora
(
def
_resolve_mm_lora
(
self
,
self
,
...
@@ -744,13 +768,13 @@ class LLM:
...
@@ -744,13 +768,13 @@ class LLM:
# only runs for one step
# only runs for one step
# we don't need to use tqdm here
# we don't need to use tqdm here
raw_
output
=
self
.
_render_and_run_requests
(
output
=
self
.
_render_and_run_requests
(
prompts
=
(
beam
.
get_prompt
()
for
beam
in
all_beams
),
prompts
=
(
beam
.
get_prompt
()
for
beam
in
all_beams
),
params
=
self
.
_params_to_seq
(
sampling_params
,
len
(
all_beams
)),
params
=
self
.
_params_to_seq
(
sampling_params
,
len
(
all_beams
)),
output_type
=
RequestOutput
,
lora_requests
=
[
beam
.
lora_request
for
beam
in
all_beams
],
lora_requests
=
[
beam
.
lora_request
for
beam
in
all_beams
],
use_tqdm
=
False
,
use_tqdm
=
False
,
)
)
output
=
self
.
engine_class
.
validate_outputs
(
raw_output
,
RequestOutput
)
for
(
start
,
end
),
instance
in
zip
(
for
(
start
,
end
),
instance
in
zip
(
instance_start_and_end
,
instances_batch
instance_start_and_end
,
instances_batch
...
@@ -987,9 +1011,10 @@ class LLM:
...
@@ -987,9 +1011,10 @@ class LLM:
if
sampling_params
is
None
:
if
sampling_params
is
None
:
sampling_params
=
self
.
get_default_sampling_params
()
sampling_params
=
self
.
get_default_sampling_params
()
outputs
=
self
.
_run_chat
(
return
self
.
_run_chat
(
messages
=
messages
,
messages
=
messages
,
params
=
sampling_params
,
params
=
sampling_params
,
output_type
=
RequestOutput
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
...
@@ -1002,8 +1027,6 @@ class LLM:
...
@@ -1002,8 +1027,6 @@ class LLM:
mm_processor_kwargs
=
mm_processor_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
encode
(
def
encode
(
self
,
self
,
prompts
:
PromptType
|
Sequence
[
PromptType
]
|
DataPrompt
,
prompts
:
PromptType
|
Sequence
[
PromptType
]
|
DataPrompt
,
...
@@ -1135,19 +1158,16 @@ class LLM:
...
@@ -1135,19 +1158,16 @@ class LLM:
outputs
=
self
.
_run_completion
(
outputs
=
self
.
_run_completion
(
prompts
=
prompts_seq
,
prompts
=
prompts_seq
,
params
=
params_seq
,
params
=
params_seq
,
output_type
=
PoolingRequestOutput
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
model_outputs
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
if
use_io_processor
:
if
use_io_processor
:
# get the post-processed model outputs
# get the post-processed model outputs
assert
self
.
io_processor
is
not
None
assert
self
.
io_processor
is
not
None
processed_outputs
=
self
.
io_processor
.
post_process
(
model_
outputs
)
processed_outputs
=
self
.
io_processor
.
post_process
(
outputs
)
return
[
return
[
PoolingRequestOutput
[
Any
](
PoolingRequestOutput
[
Any
](
...
@@ -1160,8 +1180,8 @@ class LLM:
...
@@ -1160,8 +1180,8 @@ class LLM:
finished
=
True
,
finished
=
True
,
)
)
]
]
else
:
return
model_
outputs
return
outputs
def
embed
(
def
embed
(
self
,
self
,
...
@@ -1353,8 +1373,7 @@ class LLM:
...
@@ -1353,8 +1373,7 @@ class LLM:
embed_2
=
encoded_output_2
,
embed_2
=
encoded_output_2
,
)
)
items
=
self
.
engine_class
.
validate_outputs
(
scores
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
scores
]
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
_late_interaction_score
(
def
_late_interaction_score
(
self
,
self
,
...
@@ -1393,7 +1412,7 @@ class LLM:
...
@@ -1393,7 +1412,7 @@ class LLM:
)
)
text_2
.
append
(
text
)
text_2
.
append
(
text
)
encoded_output
:
list
[
PoolingRequestOutput
]
=
self
.
encode
(
encoded_output
=
self
.
encode
(
text_1
+
text_2
,
text_1
+
text_2
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
@@ -1402,8 +1421,8 @@ class LLM:
...
@@ -1402,8 +1421,8 @@ class LLM:
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
encoded_output_1
:
list
[
PoolingRequestOutput
]
=
encoded_output
[
0
:
len
(
text_1
)]
encoded_output_1
=
encoded_output
[
0
:
len
(
text_1
)]
encoded_output_2
:
list
[
PoolingRequestOutput
]
=
encoded_output
[
len
(
text_1
)
:]
encoded_output_2
=
encoded_output
[
len
(
text_1
)
:]
if
len
(
encoded_output_1
)
==
1
:
if
len
(
encoded_output_1
)
==
1
:
encoded_output_1
=
encoded_output_1
*
len
(
encoded_output_2
)
encoded_output_1
=
encoded_output_1
*
len
(
encoded_output_2
)
...
@@ -1434,8 +1453,7 @@ class LLM:
...
@@ -1434,8 +1453,7 @@ class LLM:
)
)
)
)
items
=
self
.
engine_class
.
validate_outputs
(
scores
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
scores
]
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
_cross_encoding_score
(
def
_cross_encoding_score
(
self
,
self
,
...
@@ -1491,13 +1509,12 @@ class LLM:
...
@@ -1491,13 +1509,12 @@ class LLM:
outputs
=
self
.
_run_completion
(
outputs
=
self
.
_run_completion
(
prompts
=
prompts
,
prompts
=
prompts
,
params
=
pooling_params_list
,
params
=
pooling_params_list
,
output_type
=
PoolingRequestOutput
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
items
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
outputs
]
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
score
(
def
score
(
self
,
self
,
...
@@ -1759,6 +1776,7 @@ class LLM:
...
@@ -1759,6 +1776,7 @@ class LLM:
params
:
SamplingParams
params
:
SamplingParams
|
PoolingParams
|
PoolingParams
|
Sequence
[
SamplingParams
|
PoolingParams
],
|
Sequence
[
SamplingParams
|
PoolingParams
],
output_type
:
type
[
_O
],
*
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
...
@@ -1790,6 +1808,7 @@ class LLM:
...
@@ -1790,6 +1808,7 @@ class LLM:
)
)
),
),
params
=
seq_params
,
params
=
seq_params
,
output_type
=
output_type
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_requests
=
seq_lora_requests
,
lora_requests
=
seq_lora_requests
,
priorities
=
seq_priority
,
priorities
=
seq_priority
,
...
@@ -1802,6 +1821,7 @@ class LLM:
...
@@ -1802,6 +1821,7 @@ class LLM:
params
:
SamplingParams
params
:
SamplingParams
|
PoolingParams
|
PoolingParams
|
Sequence
[
SamplingParams
|
PoolingParams
],
|
Sequence
[
SamplingParams
|
PoolingParams
],
output_type
:
type
[
_O
],
*
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
...
@@ -1848,6 +1868,7 @@ class LLM:
...
@@ -1848,6 +1868,7 @@ class LLM:
)
)
),
),
params
=
seq_params
,
params
=
seq_params
,
output_type
=
output_type
,
lora_requests
=
seq_lora_requests
,
lora_requests
=
seq_lora_requests
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
)
)
...
@@ -1856,6 +1877,7 @@ class LLM:
...
@@ -1856,6 +1877,7 @@ class LLM:
self
,
self
,
prompts
:
Iterable
[
ProcessorInputs
],
prompts
:
Iterable
[
ProcessorInputs
],
params
:
Sequence
[
SamplingParams
|
PoolingParams
],
params
:
Sequence
[
SamplingParams
|
PoolingParams
],
output_type
:
type
[
_O
],
*
,
*
,
lora_requests
:
Sequence
[
LoRARequest
|
None
]
|
None
=
None
,
lora_requests
:
Sequence
[
LoRARequest
|
None
]
|
None
=
None
,
priorities
:
Sequence
[
int
]
|
None
=
None
,
priorities
:
Sequence
[
int
]
|
None
=
None
,
...
@@ -1878,7 +1900,7 @@ class LLM:
...
@@ -1878,7 +1900,7 @@ class LLM:
priorities
=
priorities
,
priorities
=
priorities
,
)
)
return
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
_run_engine
(
output_type
,
use_tqdm
=
use_tqdm
)
def
_render_and_add_requests
(
def
_render_and_add_requests
(
self
,
self
,
...
@@ -1932,9 +1954,10 @@ class LLM:
...
@@ -1932,9 +1954,10 @@ class LLM:
def
_run_engine
(
def
_run_engine
(
self
,
self
,
output_type
:
type
[
_O
]
|
tuple
[
type
[
_O
],
...],
*
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
)
->
list
[
RequestOutput
|
PoolingRequestOutput
]:
)
->
list
[
_O
]:
# Initialize tqdm.
# Initialize tqdm.
if
use_tqdm
:
if
use_tqdm
:
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
...
@@ -1947,14 +1970,15 @@ class LLM:
...
@@ -1947,14 +1970,15 @@ class LLM:
)
)
# Run the engine.
# Run the engine.
outputs
:
list
[
RequestOutput
|
PoolingRequestOutput
]
=
[]
outputs
:
list
[
_O
]
=
[]
total_in_toks
=
0
total_in_toks
=
0
total_out_toks
=
0
total_out_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
while
self
.
llm_engine
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_engine
.
step
()
step_outputs
=
self
.
llm_engine
.
step
()
for
output
in
step_outputs
:
for
output
in
step_outputs
:
assert
isinstance
(
output
,
output_type
)
if
output
.
finished
:
if
output
.
finished
:
outputs
.
append
(
output
)
outputs
.
append
(
output
)
# type: ignore[arg-type]
if
use_tqdm
:
if
use_tqdm
:
if
isinstance
(
output
,
RequestOutput
):
if
isinstance
(
output
,
RequestOutput
):
# Calculate tokens only for RequestOutput
# Calculate tokens only for RequestOutput
...
...
vllm/v1/engine/llm_engine.py
View file @
ac900c89
...
@@ -199,10 +199,6 @@ class LLMEngine:
...
@@ -199,10 +199,6 @@ class LLMEngine:
self
.
should_execute_dummy_batch
=
True
self
.
should_execute_dummy_batch
=
True
return
aggregated_has_unfinished
return
aggregated_has_unfinished
@
classmethod
def
validate_outputs
(
cls
,
outputs
,
output_type
):
return
outputs
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
if
not
hasattr
(
self
,
"_supported_tasks"
):
if
not
hasattr
(
self
,
"_supported_tasks"
):
# Cache the result
# Cache the result
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment