Unverified Commit 3b312fb7 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Minor] Various small code cleanups/simplifications (#31508)


Signed-off-by: default avatarnjhill <nickhill123@gmail.com>
parent f84bf7d7
......@@ -1579,14 +1579,14 @@ class ModelConfig:
@property
def is_hybrid(self) -> bool:
if not self._model_info.is_hybrid:
return False
# Handle granite-4.0-micro case which uses hybrid config but does not
# actually contain any non-attention layers.
layer_types = getattr(self.hf_config, "layer_types", None)
if layer_types is not None and all(
return layer_types is None or not all(
layer == "attention" for layer in layer_types
):
return False
return self._model_info.is_hybrid
)
@property
def has_noops(self) -> bool:
......
......@@ -2005,7 +2005,6 @@ class OpenAIServingResponses(OpenAIServing):
return event
async with AsyncExitStack() as exit_stack:
processer = None
if self.use_harmony:
# TODO: in streaming, we noticed this bug:
# https://github.com/vllm-project/vllm/issues/25697
......
......@@ -44,11 +44,8 @@ class RenderConfig:
def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> int | None:
"""Validate and normalize `truncate_prompt_tokens` parameter."""
truncate_prompt_tokens = self.truncate_prompt_tokens
if truncate_prompt_tokens is None:
return None
if truncate_prompt_tokens == 0:
return 0
if truncate_prompt_tokens is None or truncate_prompt_tokens == 0:
return truncate_prompt_tokens
if truncate_prompt_tokens < 0:
truncate_prompt_tokens = model_config.max_model_len
......
......@@ -686,11 +686,7 @@ class InputPreprocessor:
mm_uuids: MultiModalUUIDDict | None = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
res = self._preprocess(
prompt,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)
if self.mm_processor_cache and self.mm_cache_stats is not None:
delta = self.mm_processor_cache.make_stats(delta=True)
......
......@@ -171,10 +171,7 @@ class PlaceholderRange:
@cached_property
def embeds_cumsum(self) -> torch.Tensor | None:
if self.is_embed is None:
return None
return self.is_embed.cumsum(dim=0)
return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
@cached_property
def get_num_embeds(self) -> int:
......@@ -308,13 +305,7 @@ def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> boo
Equality check between
[`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.
"""
for k in a:
if k not in b:
return False
if not nested_tensors_equal(a[k], b[k]):
return False
return True
return all(k in b and nested_tensors_equal(a[k], b[k]) for k in a)
@dataclass
......
......@@ -487,10 +487,8 @@ class EngineCore:
request_ids = []
while not self.aborts_queue.empty():
ids = self.aborts_queue.get_nowait()
if isinstance(ids, str):
# Should be a list here, but also handle string just in case.
ids = (ids,)
request_ids.extend(ids)
request_ids.extend((ids,) if isinstance(ids, str) else ids)
# More efficient to abort all as a single batch.
self.abort_requests(request_ids)
......
......@@ -618,7 +618,7 @@ class InputProcessor:
tokenizer = self.tokenizer
if tokenizer is not None:
max_input_id = max(prompt_ids or [], default=0)
max_input_id = max(prompt_ids or (), default=0)
# NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
# self.model_config.get_vocab_size() is the model’s vocab size.
......
......@@ -339,10 +339,7 @@ class RequestState:
stop_reason=stop_reason if finished else None,
)
def _new_pooling_output(
self,
pooling_output: torch.Tensor,
) -> PoolingOutput:
def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput:
return PoolingOutput(data=pooling_output)
......@@ -695,9 +692,7 @@ class OutputProcessor:
assert req_state.stats is not None
iteration_stats.update_from_finished_request(
finish_reason=finish_reason,
num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds
),
num_prompt_tokens=req_state.prompt_len,
max_tokens_param=req_state.max_tokens_param,
req_stats=req_state.stats,
num_cached_tokens=req_state.num_cached_tokens,
......
......@@ -695,7 +695,7 @@ class WorkerProc:
worker = None
# tuple[Connection, Connection]
reader, ready_writer = kwargs.pop("ready_pipe")
death_pipe = kwargs.pop("death_pipe", None)
death_pipe: Connection | None = kwargs.pop("death_pipe", None)
shutdown_event = threading.Event()
# Start death monitoring thread if death_pipe is provided
if death_pipe is not None:
......
......@@ -211,8 +211,7 @@ class Request:
def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self.mm_features)
num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
return num_embeds
return self.mm_features[input_id].mm_position.get_num_embeds
def record_event(
self,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import multiprocessing
from collections.abc import Iterable
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING
......@@ -172,7 +174,7 @@ class StructuredOutputManager:
def _fill_bitmasks(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
batch: Iterable[tuple[StructuredOutputGrammar, int, bool]],
) -> None:
assert self._grammar_bitmask is not None
for grammar, index, apply_bitmask in batch:
......@@ -265,16 +267,16 @@ class StructuredOutputManager:
apply_bitmask = self.should_fill_bitmask(request)
state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
for i, token in enumerate(req_tokens + [None]):
req_tokens = scheduled_spec_decode_tokens.get(req_id, ())
for token in itertools.chain(req_tokens, (None,)):
self._fill_bitmasks(
[
(
(
structured_output_request.grammar,
cumulative_index,
apply_bitmask,
),
)
]
)
if (
......
......@@ -28,12 +28,9 @@ class StructuredOutputRequest:
if sampling_params is None:
return None
params = sampling_params.structured_outputs
if params:
if params.all_constraints_none():
if not params or params.all_constraints_none():
return None
else:
return StructuredOutputRequest(params=params)
return None
def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports
......
......@@ -829,7 +829,7 @@ class InputBatch:
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=output_token_ids,
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
spec_token_ids=self.spec_token_ids,
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
......
......@@ -1026,7 +1026,7 @@ class GPUModelRunner(
each sequence, and a shifting is done during the next iteration
based on the number of accepted tokens.
"""
if not self.model_config.is_hybrid or not self.speculative_config:
if not self.speculative_config or not self.model_config.is_hybrid:
return
# Find the number of accepted tokens for each sequence.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment