Unverified commit fb455ed5, authored by Cyrus Leung, committed by GitHub
Browse files

[V0 Deprecation] Remove code related to per-request logits processors (#34400)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f5897613
...@@ -45,7 +45,6 @@ class MockModelConfig: ...@@ -45,7 +45,6 @@ class MockModelConfig:
multimodal_config = MultiModalConfig() multimodal_config = MultiModalConfig()
hf_config = MockHFConfig() hf_config = MockHFConfig()
hf_text_config = MockHFConfig() hf_text_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None diff_sampling_param: dict | None = None
allowed_local_media_path: str = "" allowed_local_media_path: str = ""
......
...@@ -44,7 +44,6 @@ class MockModelConfig: ...@@ -44,7 +44,6 @@ class MockModelConfig:
tokenizer_revision = None tokenizer_revision = None
multimodal_config = MultiModalConfig() multimodal_config = MultiModalConfig()
hf_config = MockHFConfig() hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None diff_sampling_param: dict | None = None
allowed_local_media_path: str = "" allowed_local_media_path: str = ""
......
...@@ -45,7 +45,6 @@ class MockModelConfig: ...@@ -45,7 +45,6 @@ class MockModelConfig:
multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig) multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
hf_config: MockHFConfig = field(default_factory=MockHFConfig) hf_config: MockHFConfig = field(default_factory=MockHFConfig)
logits_processors: list[str] | None = None logits_processors: list[str] | None = None
logits_processor_pattern: str | None = None
diff_sampling_param: dict | None = None diff_sampling_param: dict | None = None
allowed_local_media_path: str = "" allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None allowed_media_domains: list[str] | None = None
......
...@@ -521,7 +521,6 @@ class MockModelConfig: ...@@ -521,7 +521,6 @@ class MockModelConfig:
hf_config = MockHFConfig() hf_config = MockHFConfig()
hf_text_config = MockHFConfig() hf_text_config = MockHFConfig()
logits_processors: list[str] | None = None logits_processors: list[str] | None = None
logits_processor_pattern = None
diff_sampling_param: dict | None = None diff_sampling_param: dict | None = None
allowed_local_media_path: str = "" allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None allowed_media_domains: list[str] | None = None
......
...@@ -144,20 +144,6 @@ def test_bad_words(llm): ...@@ -144,20 +144,6 @@ def test_bad_words(llm):
assert not contains_bad_word(new_text, new_tokens, bad_words_2) assert not contains_bad_word(new_text, new_tokens, bad_words_2)
def test_logits_processor(llm):
"""Check that we reject logits processor."""
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits
with pytest.raises(ValueError):
_ = llm.generate(PROMPT, SamplingParams(logits_processors=[pick_ith]))
def test_allowed_token_ids(llm): def test_allowed_token_ids(llm):
"""Check that we can use allowed_token_ids.""" """Check that we can use allowed_token_ids."""
......
...@@ -252,10 +252,6 @@ class ModelConfig: ...@@ -252,10 +252,6 @@ class ModelConfig:
hf_overrides: HfOverrides = field(default_factory=dict) hf_overrides: HfOverrides = field(default_factory=dict)
"""If a dictionary, contains arguments to be forwarded to the Hugging Face """If a dictionary, contains arguments to be forwarded to the Hugging Face
config. If a callable, it is called to update the HuggingFace config.""" config. If a callable, it is called to update the HuggingFace config."""
logits_processor_pattern: str | None = None
"""Optional regex pattern specifying valid logits processor qualified names
that can be passed with the `logits_processors` extra completion argument.
Defaults to `None`, which allows no processors."""
generation_config: str = "auto" generation_config: str = "auto"
"""The folder path to the generation config. Defaults to `"auto"`, the """The folder path to the generation config. Defaults to `"auto"`, the
generation config will be loaded from model path. If set to `"vllm"`, no generation config will be loaded from model path. If set to `"vllm"`, no
...@@ -342,7 +338,6 @@ class ModelConfig: ...@@ -342,7 +338,6 @@ class ModelConfig:
"config_format", "config_format",
"hf_token", "hf_token",
"hf_overrides", "hf_overrides",
"logits_processor_pattern",
"override_attention_dtype", "override_attention_dtype",
"logits_processors", "logits_processors",
"io_processor_plugin", "io_processor_plugin",
......
...@@ -508,8 +508,6 @@ class EngineArgs: ...@@ -508,8 +508,6 @@ class EngineArgs:
reasoning_parser: str = StructuredOutputsConfig.reasoning_parser reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
reasoning_parser_plugin: str | None = None reasoning_parser_plugin: str | None = None
logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
speculative_config: dict[str, Any] | None = None speculative_config: dict[str, Any] | None = None
show_hidden_metrics_for_version: str | None = ( show_hidden_metrics_for_version: str | None = (
...@@ -710,9 +708,6 @@ class EngineArgs: ...@@ -710,9 +708,6 @@ class EngineArgs:
) )
model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"]) model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
model_group.add_argument(
"--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
)
model_group.add_argument( model_group.add_argument(
"--generation-config", **model_kwargs["generation_config"] "--generation-config", **model_kwargs["generation_config"]
) )
...@@ -1320,7 +1315,6 @@ class EngineArgs: ...@@ -1320,7 +1315,6 @@ class EngineArgs:
mm_encoder_tp_mode=self.mm_encoder_tp_mode, mm_encoder_tp_mode=self.mm_encoder_tp_mode,
mm_encoder_attn_backend=self.mm_encoder_attn_backend, mm_encoder_attn_backend=self.mm_encoder_attn_backend,
pooler_config=self.pooler_config, pooler_config=self.pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config, generation_config=self.generation_config,
override_generation_config=self.override_generation_config, override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode, enable_sleep_mode=self.enable_sleep_mode,
...@@ -1429,7 +1423,7 @@ class EngineArgs: ...@@ -1429,7 +1423,7 @@ class EngineArgs:
self.model_weights = model_config.model_weights self.model_weights = model_config.model_weights
self.tokenizer = model_config.tokenizer self.tokenizer = model_config.tokenizer
self._check_feature_supported(model_config) self._check_feature_supported()
self._set_default_chunked_prefill_and_prefix_caching_args(model_config) self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
self._set_default_max_num_seqs_and_batched_tokens_args( self._set_default_max_num_seqs_and_batched_tokens_args(
usage_context, model_config usage_context, model_config
...@@ -1831,11 +1825,8 @@ class EngineArgs: ...@@ -1831,11 +1825,8 @@ class EngineArgs:
return config return config
def _check_feature_supported(self, model_config: ModelConfig): def _check_feature_supported(self):
"""Raise an error if the feature is not supported.""" """Raise an error if the feature is not supported."""
if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
_raise_unsupported_error(feature_name="--logits-processor-pattern")
# No Concurrent Partial Prefills so far. # No Concurrent Partial Prefills so far.
if ( if (
self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
......
...@@ -26,13 +26,11 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -26,13 +26,11 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
FunctionDefinition, FunctionDefinition,
LegacyStructuralTagResponseFormat, LegacyStructuralTagResponseFormat,
LogitsProcessors,
OpenAIBaseModel, OpenAIBaseModel,
StreamOptions, StreamOptions,
StructuralTagResponseFormat, StructuralTagResponseFormat,
ToolCall, ToolCall,
UsageInfo, UsageInfo,
get_logits_processors,
) )
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -293,19 +291,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -293,19 +291,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
"through out the inference process and return in response." "through out the inference process and return in response."
), ),
) )
logits_processors: LogitsProcessors | None = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."
),
)
return_tokens_as_token_ids: bool | None = Field( return_tokens_as_token_ids: bool | None = Field(
default=None, default=None,
description=( description=(
...@@ -324,6 +310,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -324,6 +310,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
"need to map generated text back to input tokens." "need to map generated text back to input tokens."
), ),
) )
cache_salt: str | None = Field( cache_salt: str | None = Field(
default=None, default=None,
description=( description=(
...@@ -335,6 +322,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -335,6 +322,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
"to 256 bit)." "to 256 bit)."
), ),
) )
kv_transfer_params: dict[str, Any] | None = Field( kv_transfer_params: dict[str, Any] | None = Field(
default=None, default=None,
description="KVTransfer parameters used for disaggregated serving.", description="KVTransfer parameters used for disaggregated serving.",
...@@ -417,7 +405,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -417,7 +405,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
def to_sampling_params( def to_sampling_params(
self, self,
max_tokens: int, max_tokens: int,
logits_processor_pattern: str | None,
default_sampling_params: dict, default_sampling_params: dict,
) -> SamplingParams: ) -> SamplingParams:
# Default parameters # Default parameters
...@@ -502,9 +489,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -502,9 +489,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
logits_processors=get_logits_processors(
self.logits_processors, logits_processor_pattern
),
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA output_kind=RequestOutputKind.DELTA
......
...@@ -86,7 +86,6 @@ from vllm.tool_parsers import ToolParser ...@@ -86,7 +86,6 @@ from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.tool_parsers.utils import partial_json_loads from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -130,9 +129,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -130,9 +129,6 @@ class OpenAIServingChat(OpenAIServing):
self.enable_log_outputs = enable_log_outputs self.enable_log_outputs = enable_log_outputs
self.enable_log_deltas = enable_log_deltas self.enable_log_deltas = enable_log_deltas
# set up logits processors
self.logits_processors = self.model_config.logits_processors
# set up reasoning parser # set up reasoning parser
self.reasoning_parser_cls = ParserManager.get_reasoning_parser( self.reasoning_parser_cls = ParserManager.get_reasoning_parser(
reasoning_parser_name=reasoning_parser reasoning_parser_name=reasoning_parser
...@@ -403,13 +399,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -403,13 +399,8 @@ class OpenAIServingChat(OpenAIServing):
else: else:
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
max_tokens, max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params, self.default_sampling_params,
) )
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)
self._log_inputs( self._log_inputs(
sub_request_id, sub_request_id,
......
...@@ -15,12 +15,10 @@ from vllm.config import ModelConfig ...@@ -15,12 +15,10 @@ from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat, AnyResponseFormat,
LegacyStructuralTagResponseFormat, LegacyStructuralTagResponseFormat,
LogitsProcessors,
OpenAIBaseModel, OpenAIBaseModel,
StreamOptions, StreamOptions,
StructuralTagResponseFormat, StructuralTagResponseFormat,
UsageInfo, UsageInfo,
get_logits_processors,
) )
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -117,19 +115,6 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -117,19 +115,6 @@ class CompletionRequest(OpenAIBaseModel):
"through out the inference process and return in response." "through out the inference process and return in response."
), ),
) )
logits_processors: LogitsProcessors | None = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."
),
)
return_tokens_as_token_ids: bool | None = Field( return_tokens_as_token_ids: bool | None = Field(
default=None, default=None,
...@@ -221,7 +206,6 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -221,7 +206,6 @@ class CompletionRequest(OpenAIBaseModel):
def to_sampling_params( def to_sampling_params(
self, self,
max_tokens: int, max_tokens: int,
logits_processor_pattern: str | None,
default_sampling_params: dict | None = None, default_sampling_params: dict | None = None,
) -> SamplingParams: ) -> SamplingParams:
if default_sampling_params is None: if default_sampling_params is None:
...@@ -312,9 +296,6 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -312,9 +296,6 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
logits_processors=get_logits_processors(
self.logits_processors, logits_processor_pattern
),
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA output_kind=RequestOutputKind.DELTA
if self.stream if self.stream
......
...@@ -42,7 +42,6 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams ...@@ -42,7 +42,6 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -67,9 +66,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -67,9 +66,6 @@ class OpenAIServingCompletion(OpenAIServing):
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
) )
# set up logits processors
self.logits_processors = self.model_config.logits_processors
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage self.enable_force_include_usage = enable_force_include_usage
...@@ -178,13 +174,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -178,13 +174,8 @@ class OpenAIServingCompletion(OpenAIServing):
else: else:
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
max_tokens, max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params, self.default_sampling_params,
) )
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
......
...@@ -15,7 +15,6 @@ from pydantic.dataclasses import dataclass ...@@ -15,7 +15,6 @@ from pydantic.dataclasses import dataclass
from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.v1.serial_utils import PydanticMsgspecMixin from vllm.v1.serial_utils import PydanticMsgspecMixin
...@@ -207,11 +206,6 @@ class SamplingParams( ...@@ -207,11 +206,6 @@ class SamplingParams(
"""Whether to skip special tokens in the output.""" """Whether to skip special tokens in the output."""
spaces_between_special_tokens: bool = True spaces_between_special_tokens: bool = True
"""Whether to add spaces between special tokens in the output.""" """Whether to add spaces between special tokens in the output."""
# `list[LogitsProcessor] | None` type. We use Any here because
# `list[LogitsProcessor] | None` type is not supported by msgspec.
logits_processors: Any | None = None
"""Functions that modify logits based on previously generated tokens, and
optionally prompt tokens as a first argument."""
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
"""Whether to include the stop strings in output text.""" """Whether to include the stop strings in output text."""
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
...@@ -277,7 +271,6 @@ class SamplingParams( ...@@ -277,7 +271,6 @@ class SamplingParams(
detokenize: bool = True, detokenize: bool = True,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True, spaces_between_special_tokens: bool = True,
logits_processors: list[LogitsProcessor] | None = None,
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None, truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: StructuredOutputsParams | None = None, structured_outputs: StructuredOutputsParams | None = None,
...@@ -318,7 +311,6 @@ class SamplingParams( ...@@ -318,7 +311,6 @@ class SamplingParams(
detokenize=detokenize, detokenize=detokenize,
skip_special_tokens=skip_special_tokens, skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind, output_kind=output_kind,
structured_outputs=structured_outputs, structured_outputs=structured_outputs,
...@@ -455,11 +447,6 @@ class SamplingParams( ...@@ -455,11 +447,6 @@ class SamplingParams(
parameter="prompt_logprobs", parameter="prompt_logprobs",
value=self.prompt_logprobs, value=self.prompt_logprobs,
) )
if self.logits_processors:
# TODO: Remove `logits_processors` attribute
raise ValueError(
"vLLM V1 does not support per request user-provided logits processors."
)
if self.truncate_prompt_tokens is not None and ( if self.truncate_prompt_tokens is not None and (
self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1 self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
): ):
...@@ -573,28 +560,11 @@ class SamplingParams( ...@@ -573,28 +560,11 @@ class SamplingParams(
return self._bad_words_token_ids return self._bad_words_token_ids
def clone(self) -> "SamplingParams": def clone(self) -> "SamplingParams":
"""Deep copy, but maybe not the LogitsProcessor objects. """If skip_clone is True, uses shallow copy instead of deep copy."""
LogitsProcessor objects may contain an arbitrary, nontrivial amount of
data that is expensive to copy. However, if not copied, the processor
needs to support parallel decoding for multiple sequences
See https://github.com/vllm-project/vllm/issues/3087
If skip_clone is True, uses shallow copy instead of deep copy.
"""
if self.skip_clone: if self.skip_clone:
return copy.copy(self) return copy.copy(self)
logit_processor_refs = ( return copy.deepcopy(self)
None
if self.logits_processors is None
else {
id(lp): lp.clone() if hasattr(lp, "clone") else lp
for lp in self.logits_processors
}
)
return copy.deepcopy(self, memo=logit_processor_refs)
def verify( def verify(
self, self,
...@@ -605,6 +575,7 @@ class SamplingParams( ...@@ -605,6 +575,7 @@ class SamplingParams(
) -> None: ) -> None:
self._validate_logprobs(model_config) self._validate_logprobs(model_config)
self._validate_logit_bias(model_config) self._validate_logit_bias(model_config)
self._validate_logits_processors(model_config)
self._validate_allowed_token_ids(tokenizer) self._validate_allowed_token_ids(tokenizer)
self._validate_spec_decode(speculative_config) self._validate_spec_decode(speculative_config)
self._validate_structured_outputs(structured_outputs_config, tokenizer) self._validate_structured_outputs(structured_outputs_config, tokenizer)
...@@ -658,6 +629,13 @@ class SamplingParams( ...@@ -658,6 +629,13 @@ class SamplingParams(
value=invalid_token_ids, value=invalid_token_ids,
) )
def _validate_logits_processors(self, model_config: ModelConfig) -> None:
from vllm.v1.sample.logits_processor import (
validate_logits_processors_parameters,
)
validate_logits_processors_parameters(model_config.logits_processors, self)
def _validate_allowed_token_ids(self, tokenizer: TokenizerLike | None) -> None: def _validate_allowed_token_ids(self, tokenizer: TokenizerLike | None) -> None:
allowed_token_ids = self.allowed_token_ids allowed_token_ids = self.allowed_token_ids
if allowed_token_ids is None: if allowed_token_ids is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment