Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4aed506b
Unverified
Commit
4aed506b
authored
Oct 14, 2025
by
Nick Hill
Committed by
GitHub
Oct 14, 2025
Browse files
[Core] Streamline some structured output related code (#26737)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
a86b4c58
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
119 additions
and
136 deletions
+119
-136
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+8
-10
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
+1
-1
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+12
-12
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+12
-12
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+2
-3
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+30
-35
vllm/v1/request.py
vllm/v1/request.py
+8
-10
vllm/v1/structured_output/__init__.py
vllm/v1/structured_output/__init__.py
+15
-17
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+1
-1
vllm/v1/structured_output/request.py
vllm/v1/structured_output/request.py
+25
-19
vllm/v1/structured_output/utils.py
vllm/v1/structured_output/utils.py
+2
-7
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-4
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+1
-5
No files found.
tests/v1/core/test_scheduler.py
View file @
4aed506b
...
...
@@ -30,7 +30,6 @@ from vllm.v1.kv_cache_interface import (
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output.request
import
StructuredOutputRequest
from
.utils
import
EOS_TOKEN_ID
,
create_requests
,
create_scheduler
...
...
@@ -335,10 +334,10 @@ def test_stop_via_update_from_output():
requests
[
0
].
request_id
:
[],
requests
[
1
].
request_id
:
[
10
],
},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -383,10 +382,10 @@ def test_stop_via_update_from_output():
requests
[
0
].
request_id
:
[
10
,
42
],
requests
[
1
].
request_id
:
[
13
],
},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -429,10 +428,10 @@ def test_stop_via_update_from_output():
requests
[
0
].
request_id
:
[
10
,
11
],
requests
[
1
].
request_id
:
[],
},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -470,10 +469,10 @@ def test_stop_via_update_from_output():
total_num_scheduled_tokens
=
3
,
scheduled_encoder_inputs
=
{},
scheduled_spec_decode_tokens
=
{
requests
[
0
].
request_id
:
[
EOS_TOKEN_ID
,
10
]},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -1941,7 +1940,6 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
sampling_params
=
sampling_params
,
pooling_params
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
structured_output_request
=
StructuredOutputRequest
(
sampling_params
),
)
scheduler
.
add_request
(
request
)
output
=
scheduler
.
schedule
()
...
...
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
View file @
4aed506b
...
...
@@ -26,7 +26,7 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
kv_connector_metadata
=
SharedStorageConnectorMetadata
(),
)
...
...
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
4aed506b
...
...
@@ -89,10 +89,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -168,10 +168,10 @@ def test_update_states_request_finished(model_runner):
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
{
req_id
},
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -198,10 +198,10 @@ def test_update_states_request_resumed(model_runner):
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -225,10 +225,10 @@ def test_update_states_request_resumed(model_runner):
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -256,10 +256,10 @@ def test_update_states_no_changes(model_runner):
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -291,10 +291,10 @@ def test_update_states_request_unscheduled(model_runner):
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
4aed506b
...
...
@@ -146,10 +146,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -212,10 +212,10 @@ def test_update_states_request_finished(model_runner, dist_init):
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
{
req_id
},
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -244,10 +244,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -273,10 +273,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -366,10 +366,10 @@ def test_update_states_no_changes(model_runner, dist_init):
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
@@ -403,10 +403,10 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
[]
,
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
{}
,
structured_output_request_ids
=
[]
,
grammar_bitmask
=
None
,
)
...
...
vllm/v1/core/sched/output.py
View file @
4aed506b
...
...
@@ -165,9 +165,8 @@ class SchedulerOutput:
# freed from the encoder cache.
free_encoder_mm_hashes
:
list
[
str
]
# Dict of request ids to their index within the batch
# for filling the next token bitmask
structured_output_request_ids
:
dict
[
str
,
int
]
# ids of structured outputs requests included in the bitmask, in order.
structured_output_request_ids
:
list
[
str
]
# the bitmask for the whole batch
grammar_bitmask
:
"npt.NDArray[np.int32] | None"
...
...
vllm/v1/core/sched/scheduler.py
View file @
4aed506b
...
...
@@ -5,7 +5,7 @@ import itertools
import
time
from
collections
import
defaultdict
from
collections.abc
import
Iterable
from
typing
import
Any
from
typing
import
TYPE_CHECKING
,
Any
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
EventPublisherFactory
,
KVEventBatch
...
...
@@ -34,6 +34,10 @@ from vllm.v1.request import Request, RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.structured_output
import
StructuredOutputManager
if
TYPE_CHECKING
:
import
numpy
as
np
import
numpy.typing
as
npt
logger
=
init_logger
(
__name__
)
...
...
@@ -608,11 +612,8 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens
,
req_to_new_blocks
,
)
scheduled_requests
=
(
scheduled_new_reqs
+
scheduled_running_reqs
+
scheduled_resumed_reqs
)
structured_output_request_ids
,
grammar_bitmask
=
self
.
get_grammar_bitmask
(
scheduled_
requests
,
scheduled_spec_decode_tokens
num_
scheduled_
tokens
.
keys
()
,
scheduled_spec_decode_tokens
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
new_reqs_data
,
...
...
@@ -876,32 +877,28 @@ class Scheduler(SchedulerInterface):
def
get_grammar_bitmask
(
self
,
requests
:
list
[
Reque
st
],
scheduled_request_ids
:
Iterable
[
st
r
],
scheduled_spec_decode_tokens
:
dict
[
str
,
list
[
int
]],
):
# NOTE: structured_output_request_ids maps
# a request's (request that uses structured output)
# request_id to its index in the batch.
# This will help us determine to slice the grammar bitmask
# and only applies valid mask for requests that
# uses structured decoding.
structured_output_request_ids
:
dict
[
str
,
int
]
=
{}
for
i
,
req
in
enumerate
(
requests
):
if
req
.
use_structured_output
:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids
[
req
.
request_id
]
=
i
)
->
tuple
[
list
[
str
],
"npt.NDArray[np.int32] | None"
]:
# Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids
=
[
req_id
for
req_id
in
scheduled_request_ids
if
(
req
:
=
self
.
requests
.
get
(
req_id
))
and
req
.
use_structured_output
]
if
not
structured_output_request_ids
:
bitmask
=
None
else
:
bitmask
=
self
.
structured_output_manager
.
grammar_bitmask
(
self
.
requests
,
structured_output_request_ids
,
scheduled_spec_decode_tokens
,
)
return
structured_output_request_ids
,
None
bitmask
=
self
.
structured_output_manager
.
grammar_bitmask
(
self
.
requests
,
structured_output_request_ids
,
scheduled_spec_decode_tokens
,
)
return
structured_output_request_ids
,
bitmask
def
update_from_output
(
...
...
@@ -1013,12 +1010,10 @@ class Scheduler(SchedulerInterface):
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
if
new_token_ids
and
self
.
structured_output_manager
.
should_advance
(
request
):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# checked above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
req_id
,
new_token_ids
)
struct_output_request
=
request
.
structured_output_request
assert
struct_output_request
is
not
None
assert
struct_output_request
.
grammar
is
not
None
struct_output_request
.
grammar
.
accept_tokens
(
req_id
,
new_token_ids
)
if
num_nans_in_logits
is
not
None
and
req_id
in
num_nans_in_logits
:
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
...
...
vllm/v1/request.py
View file @
4aed506b
...
...
@@ -40,7 +40,6 @@ class Request:
prompt_embeds
:
torch
.
Tensor
|
None
=
None
,
mm_features
:
list
[
MultiModalFeatureSpec
]
|
None
=
None
,
lora_request
:
Optional
[
"LoRARequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
cache_salt
:
str
|
None
=
None
,
priority
:
int
=
0
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
=
None
,
...
...
@@ -54,11 +53,12 @@ class Request:
# Because of LoRA, the eos token id can be different for each request.
self
.
eos_token_id
=
eos_token_id
self
.
lora_request
=
lora_request
self
.
structured_output_request
=
structured_output_request
self
.
structured_output_request
=
StructuredOutputRequest
.
from_sampling_params
(
sampling_params
)
self
.
arrival_time
=
arrival_time
if
arrival_time
is
not
None
else
time
.
time
()
self
.
status
=
RequestStatus
.
WAITING
self
.
use_structured_output
=
False
self
.
events
:
list
[
EngineCoreEvent
]
=
[]
self
.
stop_reason
:
int
|
str
|
None
=
None
...
...
@@ -72,9 +72,8 @@ class Request:
# Generative models.
assert
sampling_params
.
max_tokens
is
not
None
self
.
max_tokens
=
sampling_params
.
max_tokens
if
s
ampling_params
.
structured_output
s
is
not
None
:
if
s
elf
.
structured_output
_request
is
not
None
:
self
.
status
=
RequestStatus
.
WAITING_FOR_FSM
self
.
use_structured_output
=
True
if
sampling_params
.
extra_args
is
not
None
:
self
.
kv_transfer_params
=
sampling_params
.
extra_args
.
get
(
...
...
@@ -145,11 +144,6 @@ class Request:
eos_token_id
=
request
.
eos_token_id
,
arrival_time
=
request
.
arrival_time
,
lora_request
=
request
.
lora_request
,
structured_output_request
=
StructuredOutputRequest
(
sampling_params
=
request
.
sampling_params
)
if
request
.
sampling_params
else
None
,
cache_salt
=
request
.
cache_salt
,
priority
=
request
.
priority
,
trace_headers
=
request
.
trace_headers
,
...
...
@@ -170,6 +164,10 @@ class Request:
if
self
.
get_hash_new_full_blocks
is
not
None
:
self
.
block_hashes
.
extend
(
self
.
get_hash_new_full_blocks
())
@
property
def
use_structured_output
(
self
)
->
bool
:
return
self
.
structured_output_request
is
not
None
@
property
def
is_output_corrupted
(
self
)
->
bool
:
return
self
.
num_nans_in_logits
>
0
...
...
vllm/v1/structured_output/__init__.py
View file @
4aed506b
...
...
@@ -167,7 +167,7 @@ class StructuredOutputManager:
def
grammar_bitmask
(
self
,
requests
:
dict
[
str
,
Request
],
structured_output_request_ids
:
dic
t
[
str
,
int
],
structured_output_request_ids
:
lis
t
[
str
],
scheduled_spec_decode_tokens
:
dict
[
str
,
list
[
int
]],
)
->
"npt.NDArray[np.int32] | None"
:
# Prepare the structured output bitmask for this batch.
...
...
@@ -196,17 +196,16 @@ class StructuredOutputManager:
# masks for each request, one for each possible bonus token position.
# These are stored inline in the tensor and unpacked by the gpu runner.
cumulative_index
=
0
ordered_seq
=
sorted
(
structured_output_request_ids
.
items
(),
key
=
lambda
x
:
x
[
1
])
# Optimized parallel filling of bitmasks for
# non-spec, large-batch-size cases
if
(
len
(
ordered_seq
)
>
self
.
fill_bitmask_parallel_threshold
len
(
structured_output_request_ids
)
>
self
.
fill_bitmask_parallel_threshold
and
max_num_spec_tokens
==
0
):
promises
=
[]
batch
=
[]
for
req_id
,
_
in
ordered_seq
:
for
req_id
in
structured_output_request_ids
:
request
=
requests
[
req_id
]
structured_output_request
=
request
.
structured_output_request
if
TYPE_CHECKING
:
...
...
@@ -230,7 +229,7 @@ class StructuredOutputManager:
promise
.
result
()
else
:
# Fallback to serial filling of bitmasks for small-batch-size cases
for
req_id
,
_
in
ordered_seq
:
for
req_id
in
structured_output_request_ids
:
request
=
requests
[
req_id
]
structured_output_request
=
request
.
structured_output_request
...
...
@@ -295,21 +294,20 @@ class StructuredOutputManager:
assert
request
.
structured_output_request
.
grammar
is
not
None
# by default, we should always advance
# for cases that don't use thinking mode.
if
self
.
reasoner
is
not
None
:
structured_req
=
request
.
structured_output_request
if
self
.
reasoner
is
None
:
return
True
if
structured_req
.
reasoning_ended
:
return
True
structured_req
=
request
.
structured_output_request
if
structured_req
.
reasoning_ended
:
return
True
# Check if reasoning ends in *this* step
if
self
.
reasoner
.
is_reasoning_end
(
request
.
all_token_ids
):
# Reasoning just ended, so we shouldn't advance til
# next pass
structured_req
.
reasoning_ended
=
True
# Check if reasoning ends in *this* step
if
self
.
reasoner
.
is_reasoning_end
(
request
.
all_token_ids
):
# Reasoning just ended, so we shouldn't advance til
# next pass
structured_req
.
reasoning_ended
=
True
return
False
else
:
return
True
return
False
def
clear_backend
(
self
)
->
None
:
if
self
.
backend
is
not
None
:
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
4aed506b
...
...
@@ -252,7 +252,7 @@ def serialize_guidance_grammar(
def
validate_guidance_grammar
(
sampling_params
:
SamplingParams
,
tokenizer
:
llguidance
.
LLTokenizer
|
None
=
None
)
->
None
:
tp
,
grm
=
get_structured_output_key
(
sampling_params
)
tp
,
grm
=
get_structured_output_key
(
sampling_params
.
structured_outputs
)
guidance_grm
=
serialize_guidance_grammar
(
tp
,
grm
)
err
=
llguidance
.
LLMatcher
.
validate_grammar
(
guidance_grm
,
tokenizer
)
if
err
:
...
...
vllm/v1/structured_output/request.py
View file @
4aed506b
...
...
@@ -7,7 +7,7 @@ from concurrent.futures import Future
from
concurrent.futures._base
import
TimeoutError
from
typing
import
cast
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.v1.structured_output.backend_types
import
(
StructuredOutputGrammar
,
StructuredOutputKey
,
...
...
@@ -17,10 +17,19 @@ from vllm.v1.structured_output.backend_types import (
@
dataclasses
.
dataclass
class
StructuredOutputRequest
:
sampling_params
:
Sampling
Params
params
:
StructuredOutputs
Params
_grammar
:
Future
[
StructuredOutputGrammar
]
|
StructuredOutputGrammar
|
None
=
None
reasoning_ended
:
bool
|
None
=
None
@
staticmethod
def
from_sampling_params
(
sampling_params
:
SamplingParams
|
None
,
)
->
"StructuredOutputRequest | None"
:
if
sampling_params
is
None
:
return
None
params
=
sampling_params
.
structured_outputs
return
StructuredOutputRequest
(
params
=
params
)
if
params
else
None
def
_check_grammar_completion
(
self
)
->
bool
:
# NOTE: We have to lazy import to gate circular imports
from
vllm.v1.request
import
RequestStatus
...
...
@@ -53,31 +62,28 @@ class StructuredOutputRequest:
@
functools
.
cached_property
def
structured_output_key
(
self
)
->
StructuredOutputKey
:
return
get_structured_output_key
(
self
.
sampling_
params
)
return
get_structured_output_key
(
self
.
params
)
def
get_structured_output_key
(
sampling_params
:
SamplingParams
)
->
StructuredOutputKey
:
params
=
sampling_params
.
structured_outputs
assert
params
is
not
None
,
"params can't be None."
def
get_structured_output_key
(
params
:
StructuredOutputsParams
)
->
StructuredOutputKey
:
if
params
.
json
is
not
None
:
if
not
isinstance
(
params
.
json
,
str
):
json_str
=
json
.
dumps
(
params
.
json
)
else
:
json_str
=
params
.
json
return
(
StructuredOutputOptions
.
JSON
,
json_str
)
el
if
params
.
json_object
:
return
(
StructuredOutputOptions
.
JSON_OBJECT
,
""
)
el
if
params
.
regex
is
not
None
:
return
(
StructuredOutputOptions
.
REGEX
,
params
.
regex
)
el
if
params
.
choice
is
not
None
:
return
StructuredOutputOptions
.
JSON
,
json_str
if
params
.
json_object
:
return
StructuredOutputOptions
.
JSON_OBJECT
,
""
if
params
.
regex
is
not
None
:
return
StructuredOutputOptions
.
REGEX
,
params
.
regex
if
params
.
choice
is
not
None
:
if
not
isinstance
(
params
.
choice
,
str
):
json_str
=
json
.
dumps
(
params
.
choice
)
else
:
json_str
=
params
.
choice
return
(
StructuredOutputOptions
.
CHOICE
,
json_str
)
elif
params
.
grammar
is
not
None
:
return
(
StructuredOutputOptions
.
GRAMMAR
,
params
.
grammar
)
elif
params
.
structural_tag
is
not
None
:
return
(
StructuredOutputOptions
.
STRUCTURAL_TAG
,
params
.
structural_tag
)
else
:
raise
ValueError
(
"No valid structured output parameter found"
)
return
StructuredOutputOptions
.
CHOICE
,
json_str
if
params
.
grammar
is
not
None
:
return
StructuredOutputOptions
.
GRAMMAR
,
params
.
grammar
if
params
.
structural_tag
is
not
None
:
return
StructuredOutputOptions
.
STRUCTURAL_TAG
,
params
.
structural_tag
raise
ValueError
(
"No valid structured output parameter found"
)
vllm/v1/structured_output/utils.py
View file @
4aed506b
...
...
@@ -47,7 +47,6 @@ def apply_grammar_bitmask(
scheduler_output
:
SchedulerOutput
,
input_batch
:
InputBatch
,
logits
:
torch
.
Tensor
,
device
:
torch
.
device
,
)
->
None
:
"""
Apply grammar bitmask to output logits of the model with xgrammar function.
...
...
@@ -56,7 +55,6 @@ def apply_grammar_bitmask(
scheduler_output (SchedulerOutput): The result of engine scheduling.
input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward.
device (torch.device): The device that model runner running on.
"""
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
if
grammar_bitmask
is
None
:
...
...
@@ -91,10 +89,7 @@ def apply_grammar_bitmask(
dtype
=
grammar_bitmask
.
dtype
,
)
cumulative_index
=
0
seq
=
sorted
(
scheduler_output
.
structured_output_request_ids
.
items
(),
key
=
lambda
x
:
x
[
1
]
)
for
req_id
,
_
in
seq
:
for
req_id
in
scheduler_output
.
structured_output_request_ids
:
num_spec_tokens
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
)
...
...
@@ -117,7 +112,7 @@ def apply_grammar_bitmask(
xgr
.
apply_token_bitmask_inplace
(
logits
,
grammar_bitmask
.
to
(
device
,
non_blocking
=
True
),
grammar_bitmask
.
to
(
logits
.
device
,
non_blocking
=
True
),
indices
=
out_indices
if
not
skip_out_indices
else
None
,
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4aed506b
...
...
@@ -2568,10 +2568,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits
=
model_output_broadcast_data
[
"logits"
]
# Apply structured output bitmasks if present
if
scheduler_output
.
grammar_bitmask
is
not
None
:
apply_grammar_bitmask
(
scheduler_output
,
self
.
input_batch
,
logits
,
self
.
device
)
if
scheduler_output
.
structured_output_request_ids
:
apply_grammar_bitmask
(
scheduler_output
,
self
.
input_batch
,
logits
)
with
record_function_or_nullcontext
(
"Sample"
):
sampler_output
=
self
.
_sample
(
logits
,
spec_decode_metadata
)
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
4aed506b
...
...
@@ -1963,12 +1963,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
grammar_bitmask_cpu
.
zero_
()
self
.
require_structured_out_cpu
.
zero_
()
sorted_struct_requests
=
sorted
(
scheduler_output
.
structured_output_request_ids
.
items
(),
key
=
lambda
item
:
item
[
1
],
)
cumulative_mask_idx
=
0
for
req_id
,
_
in
s
orted_struc
t_requests
:
for
req_id
in
s
cheduler_output
.
structured_outpu
t_request
_id
s
:
if
req_id
not
in
self
.
input_batch
.
req_id_to_index
:
continue
batch_index
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment