Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ced2a92f
Unverified
Commit
ced2a92f
authored
Feb 12, 2026
by
Cyrus Leung
Committed by
GitHub
Feb 11, 2026
Browse files
[Refactor] Move validation to params definitions (#34362)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
e1d97c38
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
264 additions
and
245 deletions
+264
-245
vllm/pooling_params.py
vllm/pooling_params.py
+7
-10
vllm/sampling_params.py
vllm/sampling_params.py
+238
-0
vllm/v1/engine/input_processor.py
vllm/v1/engine/input_processor.py
+19
-235
No files found.
vllm/pooling_params.py
View file @
ced2a92f
...
@@ -72,7 +72,7 @@ class PoolingParams(
...
@@ -72,7 +72,7 @@ class PoolingParams(
"""Returns a deep copy of the PoolingParams instance."""
"""Returns a deep copy of the PoolingParams instance."""
return
deepcopy
(
self
)
return
deepcopy
(
self
)
def
verify
(
self
,
model_config
:
"
ModelConfig
"
)
->
None
:
def
verify
(
self
,
model_config
:
ModelConfig
)
->
None
:
# plugin task uses io_processor.parse_request to verify inputs,
# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
# skipping PoolingParams verify
if
self
.
task
==
"plugin"
:
if
self
.
task
==
"plugin"
:
...
@@ -87,12 +87,7 @@ class PoolingParams(
...
@@ -87,12 +87,7 @@ class PoolingParams(
self
.
_set_default_parameters
(
model_config
)
self
.
_set_default_parameters
(
model_config
)
self
.
_verify_valid_parameters
()
self
.
_verify_valid_parameters
()
def
_merge_default_parameters
(
def
_merge_default_parameters
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
,
model_config
:
"ModelConfig | None"
=
None
)
->
None
:
if
model_config
is
None
:
return
pooler_config
=
model_config
.
pooler_config
pooler_config
=
model_config
.
pooler_config
if
pooler_config
is
None
:
if
pooler_config
is
None
:
return
return
...
@@ -119,7 +114,9 @@ class PoolingParams(
...
@@ -119,7 +114,9 @@ class PoolingParams(
self
.
_verify_step_pooling
(
pooler_config
,
valid_parameters
)
self
.
_verify_step_pooling
(
pooler_config
,
valid_parameters
)
def
_verify_step_pooling
(
def
_verify_step_pooling
(
self
,
pooler_config
:
"PoolerConfig"
,
valid_parameters
:
list
[
str
]
self
,
pooler_config
:
PoolerConfig
,
valid_parameters
:
list
[
str
],
):
):
step_pooling_parameters
=
[
"step_tag_id"
,
"returned_token_ids"
]
step_pooling_parameters
=
[
"step_tag_id"
,
"returned_token_ids"
]
if
pooler_config
.
tok_pooling_type
!=
"STEP"
:
if
pooler_config
.
tok_pooling_type
!=
"STEP"
:
...
@@ -142,12 +139,12 @@ class PoolingParams(
...
@@ -142,12 +139,12 @@ class PoolingParams(
if
getattr
(
self
,
k
,
None
)
is
None
:
if
getattr
(
self
,
k
,
None
)
is
None
:
setattr
(
self
,
k
,
getattr
(
pooler_config
,
k
))
setattr
(
self
,
k
,
getattr
(
pooler_config
,
k
))
def
_set_default_parameters
(
self
,
model_config
:
"
ModelConfig
| None"
):
def
_set_default_parameters
(
self
,
model_config
:
ModelConfig
):
if
self
.
task
in
[
"embed"
,
"token_embed"
]:
if
self
.
task
in
[
"embed"
,
"token_embed"
]:
if
self
.
use_activation
is
None
:
if
self
.
use_activation
is
None
:
self
.
use_activation
=
True
self
.
use_activation
=
True
if
self
.
dimensions
is
not
None
and
model_config
is
not
None
:
if
self
.
dimensions
is
not
None
:
if
not
model_config
.
is_matryoshka
:
if
not
model_config
.
is_matryoshka
:
raise
ValueError
(
raise
ValueError
(
f
'Model "
{
model_config
.
served_model_name
}
" does not '
f
'Model "
{
model_config
.
served_model_name
}
" does not '
...
...
vllm/sampling_params.py
View file @
ced2a92f
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
"""Sampling parameters for text generation."""
"""Sampling parameters for text generation."""
import
copy
import
copy
import
json
from
dataclasses
import
field
from
dataclasses
import
field
from
enum
import
Enum
,
IntEnum
from
enum
import
Enum
,
IntEnum
from
functools
import
cached_property
from
functools
import
cached_property
...
@@ -11,6 +12,7 @@ from typing import Annotated, Any
...
@@ -11,6 +12,7 @@ from typing import Annotated, Any
import
msgspec
import
msgspec
from
pydantic.dataclasses
import
dataclass
from
pydantic.dataclasses
import
dataclass
from
vllm.config
import
ModelConfig
,
SpeculativeConfig
,
StructuredOutputsConfig
from
vllm.exceptions
import
VLLMValidationError
from
vllm.exceptions
import
VLLMValidationError
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logits_process
import
LogitsProcessor
from
vllm.logits_process
import
LogitsProcessor
...
@@ -453,6 +455,11 @@ class SamplingParams(
...
@@ -453,6 +455,11 @@ class SamplingParams(
parameter
=
"prompt_logprobs"
,
parameter
=
"prompt_logprobs"
,
value
=
self
.
prompt_logprobs
,
value
=
self
.
prompt_logprobs
,
)
)
if
self
.
logits_processors
:
# TODO: Remove `logits_processors` attribute
raise
ValueError
(
"vLLM V1 does not support per request user-provided logits processors."
)
if
self
.
truncate_prompt_tokens
is
not
None
and
(
if
self
.
truncate_prompt_tokens
is
not
None
and
(
self
.
truncate_prompt_tokens
==
0
or
self
.
truncate_prompt_tokens
<
-
1
self
.
truncate_prompt_tokens
==
0
or
self
.
truncate_prompt_tokens
<
-
1
):
):
...
@@ -589,6 +596,237 @@ class SamplingParams(
...
@@ -589,6 +596,237 @@ class SamplingParams(
)
)
return
copy
.
deepcopy
(
self
,
memo
=
logit_processor_refs
)
return
copy
.
deepcopy
(
self
,
memo
=
logit_processor_refs
)
def
verify
(
self
,
model_config
:
ModelConfig
,
speculative_config
:
SpeculativeConfig
|
None
,
structured_outputs_config
:
StructuredOutputsConfig
|
None
,
tokenizer
:
TokenizerLike
|
None
,
)
->
None
:
self
.
_validate_logprobs
(
model_config
)
self
.
_validate_logit_bias
(
model_config
)
self
.
_validate_allowed_token_ids
(
tokenizer
)
self
.
_validate_spec_decode
(
speculative_config
)
self
.
_validate_structured_outputs
(
structured_outputs_config
,
tokenizer
)
def
_validate_logprobs
(
self
,
model_config
:
ModelConfig
)
->
None
:
max_logprobs
=
model_config
.
max_logprobs
if
max_logprobs
==
-
1
:
max_logprobs
=
model_config
.
get_vocab_size
()
# Validate sample logprobs.
if
num_logprobs
:
=
self
.
logprobs
:
if
num_logprobs
==
-
1
:
num_logprobs
=
model_config
.
get_vocab_size
()
if
num_logprobs
>
max_logprobs
:
raise
VLLMValidationError
(
f
"Requested sample logprobs of
{
num_logprobs
}
, "
f
"which is greater than max allowed:
{
max_logprobs
}
"
,
parameter
=
"logprobs"
,
value
=
num_logprobs
,
)
# Validate prompt logprobs.
if
num_prompt_logprobs
:
=
self
.
prompt_logprobs
:
if
num_prompt_logprobs
==
-
1
:
num_prompt_logprobs
=
model_config
.
get_vocab_size
()
if
num_prompt_logprobs
>
max_logprobs
:
raise
VLLMValidationError
(
f
"Requested prompt logprobs of
{
num_prompt_logprobs
}
, "
f
"which is greater than max allowed:
{
max_logprobs
}
"
,
parameter
=
"prompt_logprobs"
,
value
=
num_prompt_logprobs
,
)
def
_validate_logit_bias
(
self
,
model_config
:
ModelConfig
)
->
None
:
"""Validate logit_bias token IDs are within vocabulary range."""
if
not
self
.
logit_bias
:
return
vocab_size
=
model_config
.
get_vocab_size
()
invalid_token_ids
=
[
token_id
for
token_id
in
self
.
logit_bias
if
token_id
<
0
or
token_id
>=
vocab_size
]
if
invalid_token_ids
:
raise
VLLMValidationError
(
f
"token_id(s)
{
invalid_token_ids
}
in logit_bias contain "
f
"out-of-vocab token ids. Vocabulary size:
{
vocab_size
}
"
,
parameter
=
"logit_bias"
,
value
=
invalid_token_ids
,
)
def
_validate_allowed_token_ids
(
self
,
tokenizer
:
TokenizerLike
|
None
)
->
None
:
allowed_token_ids
=
self
.
allowed_token_ids
if
allowed_token_ids
is
None
:
return
if
len
(
allowed_token_ids
)
==
0
:
raise
VLLMValidationError
(
"allowed_token_ids is not None and empty!"
,
parameter
=
"allowed_token_ids"
,
value
=
allowed_token_ids
,
)
if
tokenizer
is
not
None
:
vocab_size
=
len
(
tokenizer
)
invalid_token_ids
=
[
token_id
for
token_id
in
allowed_token_ids
if
token_id
<
0
or
token_id
>=
vocab_size
]
if
invalid_token_ids
:
raise
VLLMValidationError
(
"allowed_token_ids contains out-of-vocab token id!"
,
parameter
=
"allowed_token_ids"
,
value
=
invalid_token_ids
,
)
def
_validate_spec_decode
(
self
,
speculative_config
:
SpeculativeConfig
|
None
,
)
->
None
:
if
speculative_config
is
None
:
return
# Some sampling parameters are not yet compatible with spec decoding.
if
self
.
min_tokens
>
1
or
self
.
min_p
>
_SAMPLING_EPS
or
self
.
logit_bias
:
raise
ValueError
(
"The min_tokens, min_p, and logit_bias sampling parameters "
"are not yet supported with speculative decoding."
)
def
_validate_structured_outputs
(
self
,
structured_outputs_config
:
StructuredOutputsConfig
|
None
,
tokenizer
:
TokenizerLike
|
None
,
)
->
None
:
if
structured_outputs_config
is
None
or
self
.
structured_outputs
is
None
:
return
if
tokenizer
is
None
:
raise
ValueError
(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"
# noqa: E501
)
backend
=
structured_outputs_config
.
backend
if
_backend
:
=
self
.
structured_outputs
.
_backend
:
# Request-level backend selection is not supported.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_backend_was_auto` field set in the params.
if
backend
!=
_backend
and
not
(
backend
==
"auto"
and
self
.
structured_outputs
.
_backend_was_auto
):
raise
ValueError
(
"Request-level structured output backend selection is not "
f
"supported. The request specified '
{
_backend
}
', but vLLM "
f
"was initialised with '
{
backend
}
'. This error can be "
"resolved by removing '_backend' from the request."
)
else
:
self
.
structured_outputs
.
_backend
=
backend
# Request content validation
if
(
isinstance
(
self
.
structured_outputs
.
choice
,
list
)
and
not
self
.
structured_outputs
.
choice
):
# It is invalid for choice to be an empty list
raise
ValueError
(
f
"Choice '
{
self
.
structured_outputs
.
choice
}
' cannot be an empty list"
# noqa: E501
)
# Reject empty string grammar early to avoid engine-side crashes
if
(
isinstance
(
self
.
structured_outputs
.
grammar
,
str
)
and
self
.
structured_outputs
.
grammar
.
strip
()
==
""
):
raise
ValueError
(
"structured_outputs.grammar cannot be an empty string"
)
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.v1.structured_output.backend_guidance
import
(
has_guidance_unsupported_json_features
,
validate_guidance_grammar
,
)
from
vllm.v1.structured_output.backend_lm_format_enforcer
import
(
validate_structured_output_request_lm_format_enforcer
,
)
from
vllm.v1.structured_output.backend_outlines
import
(
validate_structured_output_request_outlines
,
)
from
vllm.v1.structured_output.backend_xgrammar
import
validate_xgrammar_grammar
if
backend
.
startswith
(
"xgrammar"
):
# xgrammar with no fallback
validate_xgrammar_grammar
(
self
)
elif
backend
.
startswith
(
"guidance"
):
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
if
isinstance
(
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
"Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_guidance_grammar
(
self
,
tokenizer
=
None
)
elif
backend
==
"outlines"
:
# outlines backend
validate_structured_output_request_outlines
(
self
)
elif
backend
==
"lm-format-enforcer"
:
# lm format enforcer backend
if
isinstance
(
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_structured_output_request_lm_format_enforcer
(
self
)
else
:
# NOTE: backend must be "auto" here, because we have
# checked supported_backends above.
# In this mode, we set opinionated defaults based on what we think
# will satisfy the most use cases without having to worry about
# this setting. We include fallback behavior here, but not with any
# other setting where a specific backend was specified.
try
:
validate_xgrammar_grammar
(
self
)
self
.
structured_outputs
.
_backend
=
"xgrammar"
except
ValueError
:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar.
# Check if schema has features unsupported by guidance
so_params
=
self
.
structured_outputs
skip_guidance
=
False
if
so_params
.
json
:
if
isinstance
(
so_params
.
json
,
str
):
schema
=
json
.
loads
(
so_params
.
json
)
else
:
schema
=
so_params
.
json
skip_guidance
=
has_guidance_unsupported_json_features
(
schema
)
if
isinstance
(
tokenizer
,
MistralTokenizer
)
or
skip_guidance
:
# Fall back to outlines if the tokenizer is Mistral
# or if schema contains features unsupported by guidance
validate_structured_output_request_outlines
(
self
)
self
.
structured_outputs
.
_backend
=
"outlines"
else
:
# Fall back to guidance by default.
validate_guidance_grammar
(
self
,
tokenizer
=
None
)
self
.
structured_outputs
.
_backend
=
"guidance"
# Remember that this backend was set automatically
self
.
structured_outputs
.
_backend_was_auto
=
True
# Run post-init validation. This is also important to ensure subsequent
# roundtrip serialization/deserialization won't fail.
self
.
structured_outputs
.
__post_init__
()
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
return
(
f
"SamplingParams(n=
{
self
.
n
}
, "
f
"SamplingParams(n=
{
self
.
n
}
, "
...
...
vllm/v1/engine/input_processor.py
View file @
ced2a92f
...
@@ -6,7 +6,6 @@ from collections.abc import Mapping
...
@@ -6,7 +6,6 @@ from collections.abc import Mapping
from
typing
import
Any
,
Literal
,
cast
from
typing
import
Any
,
Literal
,
cast
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
(
from
vllm.inputs.data
import
(
ProcessorInputs
,
ProcessorInputs
,
PromptType
,
PromptType
,
...
@@ -30,25 +29,13 @@ from vllm.multimodal.utils import argsort_mm_positions
...
@@ -30,25 +29,13 @@ from vllm.multimodal.utils import argsort_mm_positions
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers
import
BaseRenderer
from
vllm.renderers
import
BaseRenderer
from
vllm.renderers.inputs
import
DictPrompt
,
TokPrompt
from
vllm.renderers.inputs
import
DictPrompt
,
TokPrompt
from
vllm.sampling_params
import
_SAMPLING_EPS
,
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.tasks
import
POOLING_TASKS
,
SupportedTask
from
vllm.tasks
import
POOLING_TASKS
,
SupportedTask
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
,
random_uuid
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
,
random_uuid
from
vllm.utils.torch_utils
import
set_default_torch_num_threads
from
vllm.utils.torch_utils
import
set_default_torch_num_threads
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.metrics.stats
import
MultiModalCacheStats
from
vllm.v1.metrics.stats
import
MultiModalCacheStats
from
vllm.v1.structured_output.backend_guidance
import
(
has_guidance_unsupported_json_features
,
validate_guidance_grammar
,
)
from
vllm.v1.structured_output.backend_lm_format_enforcer
import
(
validate_structured_output_request_lm_format_enforcer
,
)
from
vllm.v1.structured_output.backend_outlines
import
(
validate_structured_output_request_outlines
,
)
from
vllm.v1.structured_output.backend_xgrammar
import
validate_xgrammar_grammar
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -64,6 +51,7 @@ class InputProcessor:
...
@@ -64,6 +51,7 @@ class InputProcessor:
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
structured_outputs_config
=
vllm_config
.
structured_outputs_config
self
.
structured_outputs_config
=
vllm_config
.
structured_outputs_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
observability_config
=
vllm_config
.
observability_config
...
@@ -101,101 +89,6 @@ class InputProcessor:
...
@@ -101,101 +89,6 @@ class InputProcessor:
def
renderer
(
self
)
->
BaseRenderer
:
def
renderer
(
self
)
->
BaseRenderer
:
return
self
.
input_preprocessor
.
renderer
return
self
.
input_preprocessor
.
renderer
def
_validate_logprobs
(
self
,
params
:
SamplingParams
,
)
->
None
:
max_logprobs
=
self
.
model_config
.
max_logprobs
if
max_logprobs
==
-
1
:
max_logprobs
=
self
.
model_config
.
get_vocab_size
()
# Validate sample logprobs.
if
params
.
logprobs
:
num_logprobs
=
params
.
logprobs
if
num_logprobs
==
-
1
:
num_logprobs
=
self
.
model_config
.
get_vocab_size
()
if
num_logprobs
>
max_logprobs
:
raise
VLLMValidationError
(
f
"Requested sample logprobs of
{
num_logprobs
}
, "
f
"which is greater than max allowed:
{
max_logprobs
}
"
,
parameter
=
"logprobs"
,
value
=
num_logprobs
,
)
# Validate prompt logprobs.
if
params
.
prompt_logprobs
:
num_prompt_logprobs
=
params
.
prompt_logprobs
if
num_prompt_logprobs
==
-
1
:
num_prompt_logprobs
=
self
.
model_config
.
get_vocab_size
()
if
num_prompt_logprobs
>
max_logprobs
:
raise
VLLMValidationError
(
f
"Requested prompt logprobs of
{
num_prompt_logprobs
}
, "
f
"which is greater than max allowed:
{
max_logprobs
}
"
,
parameter
=
"prompt_logprobs"
,
value
=
num_prompt_logprobs
,
)
def
_validate_sampling_params
(
self
,
params
:
SamplingParams
,
)
->
None
:
self
.
_validate_structured_output
(
params
)
self
.
_validate_logit_bias
(
params
)
if
params
.
allowed_token_ids
is
None
:
return
if
not
params
.
allowed_token_ids
:
raise
ValueError
(
"allowed_token_ids is not None and empty!"
)
if
self
.
tokenizer
is
None
:
# When skip_tokenizer_init=True, we can't validate token IDs
# Skip validation and let the model handle invalid tokens
return
vocab_size
=
len
(
self
.
tokenizer
)
if
not
all
(
0
<=
tid
<
vocab_size
for
tid
in
params
.
allowed_token_ids
):
raise
ValueError
(
"allowed_token_ids contains out-of-vocab token id!"
)
def
_validate_logit_bias
(
self
,
params
:
SamplingParams
,
)
->
None
:
"""Validate logit_bias token IDs are within vocabulary range."""
if
not
params
.
logit_bias
:
return
vocab_size
=
self
.
model_config
.
get_vocab_size
()
invalid_token_ids
=
[]
for
token_id
in
params
.
logit_bias
:
if
token_id
<
0
or
token_id
>=
vocab_size
:
invalid_token_ids
.
append
(
token_id
)
if
invalid_token_ids
:
raise
VLLMValidationError
(
f
"token_id(s)
{
invalid_token_ids
}
in logit_bias contain "
f
"out-of-vocab token ids. Vocabulary size:
{
vocab_size
}
"
,
parameter
=
"logit_bias"
,
value
=
invalid_token_ids
,
)
def
_validate_supported_sampling_params
(
self
,
params
:
SamplingParams
,
)
->
None
:
# Logits processors not supported.
if
params
.
logits_processors
:
raise
ValueError
(
"vLLM V1 does not support per request user-provided logits processors."
)
# Some sampling parameters are not yet compatible with spec decoding.
if
self
.
vllm_config
.
speculative_config
is
not
None
and
(
params
.
min_tokens
>
1
or
params
.
min_p
>
_SAMPLING_EPS
or
params
.
logit_bias
):
raise
ValueError
(
"The min_tokens, min_p, and logit_bias sampling parameters "
"are not yet supported with speculative decoding."
)
def
_validate_params
(
def
_validate_params
(
self
,
self
,
params
:
SamplingParams
|
PoolingParams
,
params
:
SamplingParams
|
PoolingParams
,
...
@@ -203,11 +96,15 @@ class InputProcessor:
...
@@ -203,11 +96,15 @@ class InputProcessor:
# is passed to all `process_inputs` calls
# is passed to all `process_inputs` calls
supported_tasks
:
tuple
[
SupportedTask
,
...]
|
None
,
supported_tasks
:
tuple
[
SupportedTask
,
...]
|
None
,
):
):
"""
"""Raise `ValueError` if SamplingParams or PoolingParams is not valid."""
Validate supported SamplingParam.
if
isinstance
(
params
,
SamplingParams
):
Should raise ValueError if unsupported for API Server.
params
.
verify
(
"""
self
.
model_config
,
if
isinstance
(
params
,
PoolingParams
):
self
.
speculative_config
,
self
.
structured_outputs_config
,
self
.
tokenizer
,
)
elif
isinstance
(
params
,
PoolingParams
):
if
supported_tasks
is
None
:
if
supported_tasks
is
None
:
raise
RuntimeError
(
"`supported_tasks` must be passed for pooling"
)
raise
RuntimeError
(
"`supported_tasks` must be passed for pooling"
)
...
@@ -233,12 +130,11 @@ class InputProcessor:
...
@@ -233,12 +130,11 @@ class InputProcessor:
)
)
params
.
verify
(
self
.
model_config
)
params
.
verify
(
self
.
model_config
)
else
:
return
raise
TypeError
(
f
"params must be either SamplingParams or PoolingParams, "
self
.
_validate_logprobs
(
params
)
f
"but got
{
type
(
params
).
__name__
}
"
self
.
_validate_sampling_params
(
params
)
)
self
.
_validate_supported_sampling_params
(
params
)
def
_parse_mm_items
(
self
,
mm_data
:
MultiModalDataDict
)
->
MultiModalDataItems
:
def
_parse_mm_items
(
self
,
mm_data
:
MultiModalDataDict
)
->
MultiModalDataItems
:
mm_processor
=
self
.
input_preprocessor
.
_get_mm_processor
()
mm_processor
=
self
.
input_preprocessor
.
_get_mm_processor
()
...
@@ -334,120 +230,6 @@ class InputProcessor:
...
@@ -334,120 +230,6 @@ class InputProcessor:
"[lora_path]` to use the LoRA tokenizer."
"[lora_path]` to use the LoRA tokenizer."
)
)
def
_validate_structured_output
(
self
,
params
:
SamplingParams
)
->
None
:
if
not
params
.
structured_outputs
or
not
self
.
structured_outputs_config
:
return
if
self
.
model_config
.
skip_tokenizer_init
and
params
.
structured_outputs
:
raise
ValueError
(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"
# noqa: E501
)
backend
=
self
.
structured_outputs_config
.
backend
if
_backend
:
=
params
.
structured_outputs
.
_backend
:
# Request-level backend selection is not supported.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_backend_was_auto` field set in the params.
if
backend
!=
_backend
and
not
(
backend
==
"auto"
and
params
.
structured_outputs
.
_backend_was_auto
):
raise
ValueError
(
"Request-level structured output backend selection is not "
f
"supported. The request specified '
{
_backend
}
', but vLLM "
f
"was initialised with '
{
backend
}
'. This error can be "
"resolved by removing '_backend' from the request."
)
else
:
params
.
structured_outputs
.
_backend
=
backend
# Request content validation
if
(
isinstance
(
params
.
structured_outputs
.
choice
,
list
)
and
not
params
.
structured_outputs
.
choice
):
# It is invalid for choice to be an empty list
raise
ValueError
(
f
"Choice '
{
params
.
structured_outputs
.
choice
}
' cannot be an empty list"
# noqa: E501
)
# Reject empty string grammar early to avoid engine-side crashes
if
(
isinstance
(
params
.
structured_outputs
.
grammar
,
str
)
and
params
.
structured_outputs
.
grammar
.
strip
()
==
""
):
raise
ValueError
(
"structured_outputs.grammar cannot be an empty string"
)
if
backend
.
startswith
(
"xgrammar"
):
# xgrammar with no fallback
validate_xgrammar_grammar
(
params
)
elif
backend
.
startswith
(
"guidance"
):
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
if
isinstance
(
self
.
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
"Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_guidance_grammar
(
params
,
tokenizer
=
None
)
elif
backend
==
"outlines"
:
# outlines backend
validate_structured_output_request_outlines
(
params
)
elif
backend
==
"lm-format-enforcer"
:
# lm format enforcer backend
if
isinstance
(
self
.
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_structured_output_request_lm_format_enforcer
(
params
)
else
:
# NOTE: backend must be "auto" here, because we have
# checked supported_backends above.
# In this mode, we set opinionated defaults based on what we think
# will satisfy the most use cases without having to worry about
# this setting. We include fallback behavior here, but not with any
# other setting where a specific backend was specified.
try
:
validate_xgrammar_grammar
(
params
)
params
.
structured_outputs
.
_backend
=
"xgrammar"
except
ValueError
:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar.
# Check if schema has features unsupported by guidance
so_params
=
params
.
structured_outputs
skip_guidance
=
False
if
so_params
.
json
:
if
isinstance
(
so_params
.
json
,
str
):
import
json
schema
=
json
.
loads
(
so_params
.
json
)
else
:
schema
=
so_params
.
json
skip_guidance
=
has_guidance_unsupported_json_features
(
schema
)
if
isinstance
(
self
.
tokenizer
,
MistralTokenizer
)
or
skip_guidance
:
# Fall back to outlines if the tokenizer is Mistral
# or if schema contains features unsupported by guidance
validate_structured_output_request_outlines
(
params
)
params
.
structured_outputs
.
_backend
=
"outlines"
else
:
# Fall back to guidance by default.
validate_guidance_grammar
(
params
,
tokenizer
=
None
)
params
.
structured_outputs
.
_backend
=
"guidance"
# Remember that this backend was set automatically
params
.
structured_outputs
.
_backend_was_auto
=
True
# Run post-init validation. This is also important to ensure subsequent
# roundtrip serialization/deserialization won't fail.
params
.
structured_outputs
.
__post_init__
()
def
_extract_singleton_mm_data
(
def
_extract_singleton_mm_data
(
self
,
prompt
:
SingletonPrompt
self
,
prompt
:
SingletonPrompt
)
->
MultiModalDataDict
|
None
:
)
->
MultiModalDataDict
|
None
:
...
@@ -618,8 +400,10 @@ class InputProcessor:
...
@@ -618,8 +400,10 @@ class InputProcessor:
prompt_token_ids
,
prompt_embeds
prompt_token_ids
,
prompt_embeds
)
)
sampling_params
.
max_tokens
=
self
.
model_config
.
max_model_len
-
seq_len
sampling_params
.
max_tokens
=
self
.
model_config
.
max_model_len
-
seq_len
sampling_params
.
update_from_generation_config
(
sampling_params
.
update_from_generation_config
(
self
.
generation_config_fields
,
eos_token_id
self
.
generation_config_fields
,
None
if
self
.
tokenizer
is
None
else
self
.
tokenizer
.
eos_token_id
,
)
)
if
self
.
tokenizer
is
not
None
:
if
self
.
tokenizer
is
not
None
:
sampling_params
.
update_from_tokenizer
(
self
.
tokenizer
)
sampling_params
.
update_from_tokenizer
(
self
.
tokenizer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment