Commit bf7f470b (unverified), authored by afeldman-nm, committed by GitHub
Browse files

[V1] Logits processors extensibility (#19912)


Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Signed-off-by: Andrew Feldman <afeld2012@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 4fc722ec
...@@ -18,8 +18,8 @@ from vllm.utils import swap_dict_values ...@@ -18,8 +18,8 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
MoveDirectionality, LogitsProcessors,
init_builtin_logitsprocs) MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice from vllm.v1.utils import copy_slice
...@@ -78,8 +78,11 @@ class InputBatch: ...@@ -78,8 +78,11 @@ class InputBatch:
pin_memory: bool, pin_memory: bool,
vocab_size: int, vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False, is_spec_decode: bool = False,
is_pooling_model: bool = False,
): ):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
...@@ -221,14 +224,6 @@ class InputBatch: ...@@ -221,14 +224,6 @@ class InputBatch:
# updates. Should reset each step. # updates. Should reset each step.
self.batch_update_builder = BatchUpdateBuilder() self.batch_update_builder = BatchUpdateBuilder()
# Define logits processors.
# TODO(andy): logits processor list should be extensible via engine
# constructor argument; for now the list is fixed.
self.logitsprocs = init_builtin_logitsprocs(
pin_memory_available=pin_memory,
max_num_reqs=max_num_reqs + 1,
device=device)
# TODO convert this to LogitsProcessor # TODO convert this to LogitsProcessor
self.has_allowed_token_ids: set[str] = set() self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed, # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
...@@ -244,6 +239,10 @@ class InputBatch: ...@@ -244,6 +239,10 @@ class InputBatch:
self.req_output_token_ids: list[Optional[list[int]]] = [] self.req_output_token_ids: list[Optional[list[int]]] = []
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
# This is updated each time the batch constituents change. # This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata() self.sampling_metadata = self._make_sampling_metadata()
...@@ -255,28 +254,35 @@ class InputBatch: ...@@ -255,28 +254,35 @@ class InputBatch:
# while performing state updates to the batch. # while performing state updates to the batch.
return cast(list[str], self._req_ids) return cast(list[str], self._req_ids)
def _get_next_add_index(self) -> int:
if (req_index := self.batch_update_builder.pop_removed()) is not None:
# Fill the empty index.
return req_index
# Append to end
return self.num_reqs
def _register_add_request(self, request: "CachedRequestState") -> int: def _register_add_request(self, request: "CachedRequestState") -> int:
"""Track add-request operations""" """Track add-request operations for logits processors.
req_index = self._get_next_add_index() Not applicable to pooling models.
assert req_index < self.max_num_reqs """
params = (request.sampling_params
if request.sampling_params else request.pooling_params) # Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs
assert request.sampling_params
# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs
assert new_req_index < self.max_num_reqs
self.batch_update_builder.added.append( self.batch_update_builder.added.append(
(req_index, params, request.output_token_ids)) (new_req_index, request.sampling_params, request.prompt_token_ids,
return req_index request.output_token_ids))
return new_req_index
def add_request( def add_request(
self, self,
request: "CachedRequestState", request: "CachedRequestState",
) -> int: ) -> int:
req_index = self._register_add_request(request) if not self.is_pooling_model:
# New request index bookkeeping for autoregressive models.
req_index = self._register_add_request(request)
else:
req_index = self.num_reqs
req_id = request.req_id req_id = request.req_id
if req_index == len(self._req_ids): if req_index == len(self._req_ids):
...@@ -411,7 +417,10 @@ class InputBatch: ...@@ -411,7 +417,10 @@ class InputBatch:
req_index = self.req_id_to_index.pop(req_id, None) req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None: if req_index is None:
return None return None
self.batch_update_builder.removed_append(req_index) if not self.is_pooling_model:
# Autoregressive models require bookkeeping of removed requests to
# support logitsprocs.
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None self.req_output_token_ids[req_index] = None
...@@ -446,6 +455,8 @@ class InputBatch: ...@@ -446,6 +455,8 @@ class InputBatch:
return req_index return req_index
def swap_states(self, i1: int, i2: int) -> None: def swap_states(self, i1: int, i2: int) -> None:
# For autoregressive models, track detailed request reordering info
# to support logitsprocs
self.batch_update_builder.moved.append( self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP)) (i1, i2, MoveDirectionality.SWAP))
old_id_i1 = self._req_ids[i1] old_id_i1 = self._req_ids[i1]
...@@ -513,11 +524,18 @@ class InputBatch: ...@@ -513,11 +524,18 @@ class InputBatch:
swaps: list of (from,to) swap tuples for moved requests swaps: list of (from,to) swap tuples for moved requests
empty_req_indices: indices not filled by condensation empty_req_indices: indices not filled by condensation
""" """
num_reqs = self.num_reqs
if self.is_pooling_model:
# Will be contiguous in pooling case, just trim the lists.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
return
if not (empty_req_indices := self.batch_update_builder.removed): if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no # All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed # requests were removed at all. No condense() needed
return return
num_reqs = self.num_reqs
if num_reqs == 0: if num_reqs == 0:
# The batched states are empty. # The batched states are empty.
self._req_ids.clear() self._req_ids.clear()
...@@ -541,6 +559,8 @@ class InputBatch: ...@@ -541,6 +559,8 @@ class InputBatch:
# Move active request down into empty request # Move active request down into empty request
# index. # index.
self.batch_update_builder.pop_removed() self.batch_update_builder.pop_removed()
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append( self.batch_update_builder.moved.append(
(last_req_index, empty_index, (last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL)) MoveDirectionality.UNIDIRECTIONAL))
...@@ -596,15 +616,20 @@ class InputBatch: ...@@ -596,15 +616,20 @@ class InputBatch:
last_req_index -= 1 last_req_index -= 1
# Trim lists to the batch size. # Trim lists to the batch size.
del self._req_ids[self.num_reqs:] del self._req_ids[num_reqs:]
del self.req_output_token_ids[self.num_reqs:] del self.req_output_token_ids[num_reqs:]
def refresh_metadata(self): def refresh_metadata(self):
"""Apply batch updates, reset input batch at end of step """Apply any batch updates to sampling metadata."""
* Apply batch add/remove/permute to logits procs' states if self.is_pooling_model:
* If batch state is modified, update sampling metadata # Batch changes every step for pooling models.
""" self.sampling_metadata = self._make_sampling_metadata()
return
# For non-pooling models - generate and apply logitsprocs update;
# reset batch update tracking.
# Update sampling metadata if batch state is changed.
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs) batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all: for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update) logit_proc.update_state(batch_update)
......
...@@ -68,6 +68,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, ...@@ -68,6 +68,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput) ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
...@@ -80,7 +81,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import ( ...@@ -80,7 +81,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput) KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache, from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders) sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
...@@ -221,6 +221,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -221,6 +221,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
vocab_size=self.model_config.get_vocab_size(), vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size], block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config, self.device, self.pin_memory,
self.is_pooling_model,
self.vllm_config.model_config.logits_processors),
is_pooling_model=self.is_pooling_model,
) )
# TODO(woosuk): Provide an option to tune the max cudagraph batch size. # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
...@@ -2447,7 +2452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2447,7 +2452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
output_token_ids=[[] for _ in range(num_reqs)], output_token_ids=[[] for _ in range(num_reqs)],
allowed_token_ids_mask=None, allowed_token_ids_mask=None,
bad_words_token_ids={}, bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(), logitsprocs=LogitsProcessors(),
) )
try: try:
sampler_output = self.sampler(logits=logits, sampler_output = self.sampler(logits=logits,
...@@ -2968,6 +2973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2968,6 +2973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
vocab_size=self.model_config.get_vocab_size(), vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes, block_sizes=block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model,
) )
def _allocate_kv_cache_tensors( def _allocate_kv_cache_tensors(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment