# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch

from dataclasses import dataclass, field
from typing import Optional, cast

import numpy as np
import torch

from vllm import envs
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
                                             MoveDirectionality,
                                             init_builtin_logitsprocs)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable


@dataclass
class CachedRequestState:
    """Per-request state that is copied into an `InputBatch` row.

    `num_prompt_tokens` is derived from `prompt_token_ids` in
    `__post_init__` and is not a constructor argument.
    """

    req_id: str
    prompt_token_ids: list[int]
    mm_inputs: list[MultiModalKwargs]
    mm_positions: list[PlaceholderRange]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    generator: Optional[torch.Generator]

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    num_kv_tokens: int
    output_token_ids: list[int]
    # `None` means "no speculative tokens attached to this request".
    spec_token_ids: Optional[list[int]] = None

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    # Lazily populated when `VLLM_V1_FAST_TOKEN_ID_COPY` is enabled.
    _prompt_token_ids_np: Optional[np.ndarray] = field(default=None,
                                                       init=False,
                                                       repr=False,
                                                       compare=False)

    # Chunked prefill (scheme 3): cached prompt compaction plan.
    # Computed on the last prompt chunk; applied before the first decode step.
    kv_compression_prompt_idx_sorted: Optional[torch.Tensor] = None  # [K] int32
    kv_compression_prompt_keep_len: Optional[int] = None
    kv_compression_prompt_prompt_len: Optional[int] = None

    def __post_init__(self):
        self.num_prompt_tokens = len(self.prompt_token_ids)

    @property
    def num_tokens(self) -> int:
        """Number of prompt tokens plus generated output tokens."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        """Return the token at `idx` in the prompt+output sequence."""
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        else:
            return self.output_token_ids[idx - self.num_prompt_tokens]


class InputBatch:
    """Batch state stored in buffers pre-sized for `max_num_reqs` requests.

    Each request occupies one row index. Requests are added, removed,
    swapped and condensed in place, and every such change is recorded in
    `batch_update_builder` so that `refresh_metadata()` can update the
    logits processors and rebuild `sampling_metadata`.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        is_spec_decode: bool = False,
        logits_processing_needs_token_ids: bool = False,
    ):
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size
        self.logits_processing_needs_token_ids = (
            logits_processing_needs_token_ids)

        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()
        self.num_kv_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
        )

        # Sampling-related.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty((max_num_reqs, ),
                                               dtype=torch.float,
                                               device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty((max_num_reqs, ),
                                              dtype=torch.float,
                                              device=device)
        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                         dtype=torch.float,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.presence_penalties_cpu = \
            self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
                                                device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Track whether sampling metadata is currently expanded to
        # per-token shape (spec decode reject path).
        self._sampling_metadata_is_expanded = False

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # Define logits processors.
        # TODO(andy): logits processor list should be extensible via engine
        # constructor argument; for now the list is fixed.
        self.logitsprocs = init_builtin_logitsprocs(
            pin_memory_available=pin_memory,
            max_num_reqs=max_num_reqs + 1,
            device=device)

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

    @property
    def req_ids(self) -> list[str]:
        # None elements should only be present transiently
        # while performing state updates to the batch.
        return cast(list[str], self._req_ids)

    def _get_next_add_index(self) -> int:
        """Return the row index the next added request should occupy."""
        if (req_index := self.batch_update_builder.pop_removed()) is not None:
            # Fill the empty index.
            return req_index
        # Append to end
        return self.num_reqs

    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Track add-request operations"""
        req_index = self._get_next_add_index()
        assert req_index < self.max_num_reqs
        params = (request.sampling_params
                  if request.sampling_params else request.pooling_params)
        self.batch_update_builder.added.append(
            (req_index, params, request.output_token_ids))
        return req_index

    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Copy `request` into a batch row and return that row's index.

        The row is either a previously removed index or the next index past
        the current end of the batch. Token ids, counters, the block table
        row, sampling/pooling parameters and the LoRA mapping are written
        for that row.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        # OPTIMIZATION: Use np.copyto with pre-converted NumPy arrays
        # instead of slice assignment. This avoids the ~550 µs overhead
        # of converting Python list to NumPy array each time.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
            self.token_ids_cpu[
                req_index, :num_prompt_tokens] = request.prompt_token_ids
            start_idx = num_prompt_tokens
            end_idx = start_idx + len(request.output_token_ids)
            self.token_ids_cpu[req_index,
                               start_idx:end_idx] = request.output_token_ids
        else:
            # Reuse the cached int32 copy of the prompt only when it is a
            # NumPy array of the expected dtype and length; anything else
            # (missing, stale, wrong type) triggers a rebuild.
            prompt_token_ids_np = getattr(request, "_prompt_token_ids_np",
                                          None)
            cache_is_valid = (isinstance(prompt_token_ids_np, np.ndarray)
                              and prompt_token_ids_np.dtype == np.int32
                              and prompt_token_ids_np.size
                              == num_prompt_tokens)
            if not cache_is_valid:
                prompt_token_ids_np = np.asarray(request.prompt_token_ids,
                                                 dtype=np.int32)
                try:
                    request._prompt_token_ids_np = prompt_token_ids_np
                except (AttributeError, TypeError):
                    # Caching is best-effort: a request object that does not
                    # accept the attribute simply gets rebuilt next time.
                    pass

            np.copyto(
                self.token_ids_cpu[req_index, :num_prompt_tokens],
                prompt_token_ids_np,
                casting='no',
            )
            start_idx = num_prompt_tokens
            output_token_ids_np = np.asarray(request.output_token_ids,
                                             dtype=np.int32)
            end_idx = start_idx + output_token_ids_np.size
            np.copyto(
                self.token_ids_cpu[req_index, start_idx:end_idx],
                output_token_ids_np,
                casting='no',
            )

        num_spec_tokens = 0
        if request.spec_token_ids is not None:
            num_spec_tokens = len(request.spec_token_ids)
            self.token_ids_cpu[req_index, end_idx:end_idx +
                               num_spec_tokens] = request.spec_token_ids
        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens + num_spec_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if (self.is_spec_decode
                    and is_spec_decode_unsupported(sampling_params)):
                self.spec_decode_unsupported_reqs.add(req_id)
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Avoid later division by zero.
                self.temperature_cpu[req_index] = -1.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[
                req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[
                req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[
                req_index] = sampling_params.repetition_penalty
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests
            # that do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                self.num_logprobs[req_id] = sampling_params.logprobs
            if sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[
                    req_id] = sampling_params.prompt_logprobs

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device)
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu")
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[
                    req_index] = sampling_params.bad_words_token_ids
        else:
            assert request.pooling_params is not None
            self.pooling_params[req_id] = request.pooling_params

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

    def remove_request(self, req_id: str) -> Optional[int]:
        """This method must always be followed by a call to condense().

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """

        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.spec_decode_unsupported_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            self.lora_id_to_request_ids[lora_id].discard(req_id)
            if len(self.lora_id_to_request_ids[lora_id]) == 0:
                self.lora_id_to_request_ids.pop(lora_id)
                self.lora_id_to_lora_request.pop(lora_id)
            self.request_lora_mapping[req_index] = 0

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        self.pooling_params.pop(req_id, None)
        return req_index

    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-request state stored at rows `i1` and `i2`."""
        self.batch_update_builder.moved.append(
            (i1, i2, MoveDirectionality.SWAP))
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] =\
            self._req_ids[i2], self._req_ids[i1] # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
        # The arrays below are 1-D, so integer indexing yields scalar
        # copies and the tuple swap is safe.
        self.num_tokens[i1], self.num_tokens[i2] =\
            self.num_tokens[i2], self.num_tokens[i1]
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
        self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] =\
            self.num_kv_tokens_cpu[i2], self.num_kv_tokens_cpu[i1]
        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
            self.temperature_cpu[i2], self.temperature_cpu[i1]
        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
            self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
            self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
            self.request_lora_mapping[i2], self.request_lora_mapping[i1]

        mask = self.allowed_token_ids_mask_cpu_tensor
        if mask is not None:
            # Rows of a 2-D tensor are views, so a tuple swap would write
            # row i2 into row i1 and then copy it straight back into row i2.
            # Indexing with a list materialises a copy of both rows first.
            mask[[i1, i2]] = mask[[i2, i1]]
        self.block_table.swap_row(i1, i2)

    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

        The empty indices are read from `batch_update_builder.removed`.
        Every move is appended to `batch_update_builder.moved` as a
        `(from, to, UNIDIRECTIONAL)` tuple. Nothing is returned.
        """
        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index,
                 MoveDirectionality.UNIDIRECTIONAL))
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.num_kv_tokens_cpu[
                empty_index] = self.num_kv_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            # TODO convert these to LogitsProcessors
            mask = self.allowed_token_ids_mask_cpu_tensor
            if mask is not None:
                mask[empty_index] = mask[last_req_index]
                # Clear the vacated row so that a request later added at
                # this index without allowed_token_ids is not constrained
                # by the moved request's mask.
                mask[last_req_index].fill_(False)

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]

    def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
        """Apply batch updates, reset input batch at end of step

        * Apply batch add/remove/permute to logits procs' states
        * If batch state is modified, update sampling metadata

        When `repeat_counts` is given, the metadata is rebuilt in expanded
        form via `_make_sampling_metadata_expanded`. It is also rebuilt
        when the previous call produced expanded metadata, so the expanded
        form does not outlive the step that requested it.
        """
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        needs_rebuild = (batch_update or repeat_counts is not None
                         or self._sampling_metadata_is_expanded)
        if needs_rebuild:
            if repeat_counts is None:
                self.sampling_metadata = self._make_sampling_metadata()
            else:
                self.sampling_metadata = self._make_sampling_metadata_expanded(
                    repeat_counts)
            self._sampling_metadata_is_expanded = repeat_counts is not None
        # Expanded metadata is built on demand; do not cache a copy here.

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build per-request `SamplingMetadata` from the current batch."""
        num_reqs = self.num_reqs
        if not self.all_greedy:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, num_reqs)
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(self.frequency_penalties_cpu_tensor,
                       self.frequency_penalties, num_reqs)
            copy_slice(self.presence_penalties_cpu_tensor,
                       self.presence_penalties, num_reqs)
            copy_slice(self.repetition_penalties_cpu_tensor,
                       self.repetition_penalties, num_reqs)

        needs_prompt_token_ids = (not self.no_penalties or
                                  (self.num_reqs > 0
                                   and self.logits_processing_needs_token_ids))
        if needs_prompt_token_ids:
            # The prompt tokens are used only for applying penalties or
            # step pooling during the sampling/pooling process.
            # Hence copy these tensors only when there are requests which
            # need penalties/step_pooler to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor()
        else:
            prompt_token_ids = None

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        # Host-side summaries to avoid device synchronization in sampling
        # fast paths (e.g. reduced top-k/top-p sampling).
        max_top_k = None
        has_any_no_top_k = False
        if not self.no_top_k and num_reqs > 0:
            top_k_cpu = self.top_k_cpu[:num_reqs]
            max_top_k = int(top_k_cpu.max())
            has_any_no_top_k = bool((top_k_cpu == self.vocab_size).any())

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
            max_top_k=max_top_k,
            has_any_no_top_k=has_any_no_top_k,
        )

    def _make_sampling_metadata_expanded(
            self, repeat_counts: torch.Tensor
    ) -> SamplingMetadata:
        """Build `SamplingMetadata` with each request row repeated.

        Row `i` of every per-request field is repeated `repeat_counts[i]`
        times (spec decode reject-sampling path). Requests with a
        non-positive count contribute no entries to the expanded Python
        containers.

        Args:
            repeat_counts: Per-request repeat counts. Presumably a CPU
                integer tensor, since it is applied to CPU tensors with
                `repeat_interleave` - TODO confirm against the caller.
        """
        num_reqs = self.num_reqs
        repeat_counts_cpu = repeat_counts
        all_greedy = self.all_greedy
        all_random = self.all_random
        # For reject-sampling optimization, force greedy sampling to keep
        # rejection sampler assumptions (per-request shapes) intact.
        # NOTE(review): no greedy override is applied in this method;
        # all_greedy/all_random simply mirror the batch state - confirm
        # whether the comment above is stale.

        def _expand_cpu_to_gpu(
            t: Optional[torch.Tensor],
            *,
            dtype: Optional[torch.dtype] = None,
        ) -> Optional[torch.Tensor]:
            # Repeat the first `num_reqs` rows of a CPU tensor and move the
            # result to `self.device`.
            if t is None:
                return None
            base = t[:num_reqs].repeat_interleave(repeat_counts_cpu, dim=0)
            return base.to(device=self.device,
                           dtype=dtype,
                           non_blocking=True)

        needs_prompt_token_ids = (not self.no_penalties or
                                  (self.num_reqs > 0
                                   and self.logits_processing_needs_token_ids))
        if needs_prompt_token_ids:
            # The prompt tokens are used only for applying penalties or
            # step pooling during the sampling/pooling process.
            # Hence copy these tensors only when there are requests which
            # need penalties/step_pooler to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor(
                repeat_counts_cpu)
        else:
            prompt_token_ids = None

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            allowed_token_ids_mask = self.allowed_token_ids_mask_cpu_tensor

        # Expand per-request metadata to per-token shape when repeat_counts
        # is provided (spec decode reject-sampling path).
        top_p_cpu = None if self.no_top_p else self.top_p_cpu_tensor
        top_k_cpu = None if self.no_top_k else self.top_k_cpu_tensor

        repeat_list = repeat_counts_cpu.tolist()

        expanded_output_token_ids: list[list[int]] = []
        expanded_bad_words_token_ids: dict[int, list[list[int]]] = {}
        expanded_generators: dict[int, torch.Generator] = {}
        row_idx = 0
        for req_idx in range(num_reqs):
            repeat = int(repeat_list[req_idx])
            if repeat <= 0:
                continue
            output_tokens = self.req_output_token_ids[req_idx]
            assert output_tokens is not None
            bad_words = self.bad_words_token_ids.get(req_idx)
            generator = self.generators.get(req_idx)
            # Every expanded row shares the request's list and generator
            # objects; nothing is copied.
            for _ in range(repeat):
                expanded_output_token_ids.append(output_tokens)
                if bad_words is not None:
                    expanded_bad_words_token_ids[row_idx] = bad_words
                if generator is not None:
                    expanded_generators[row_idx] = generator
                row_idx += 1

        # Host-side top-k summaries. The NumPy view gets its own name so it
        # cannot replace the torch tensor `top_k_cpu` that is expanded
        # below; `_expand_cpu_to_gpu` needs a tensor.
        max_top_k = None
        has_any_no_top_k = False
        if not self.no_top_k and num_reqs > 0:
            top_k_np = self.top_k_cpu[:num_reqs]
            max_top_k = int(top_k_np.max())
            has_any_no_top_k = bool((top_k_np == self.vocab_size).any())

        return SamplingMetadata(
            temperature=_expand_cpu_to_gpu(
                None if all_greedy else self.temperature_cpu_tensor),
            all_greedy=all_greedy,
            all_random=all_random,
            top_p=_expand_cpu_to_gpu(top_p_cpu),
            top_k=_expand_cpu_to_gpu(top_k_cpu, dtype=torch.int32),
            generators=expanded_generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=(
                None if self.no_penalties else _expand_cpu_to_gpu(
                    self.frequency_penalties_cpu_tensor)),
            presence_penalties=(
                None if self.no_penalties else _expand_cpu_to_gpu(
                    self.presence_penalties_cpu_tensor)),
            repetition_penalties=(
                None if self.no_penalties else _expand_cpu_to_gpu(
                    self.repetition_penalties_cpu_tensor)),
            output_token_ids=expanded_output_token_ids,
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=_expand_cpu_to_gpu(
                allowed_token_ids_mask, dtype=torch.bool),
            bad_words_token_ids=expanded_bad_words_token_ids,
            logitsprocs=self.logitsprocs,
            max_top_k=max_top_k,
            has_any_no_top_k=has_any_no_top_k,
        )

    @property
    def pooling_metadata(self) -> PoolingMetadata:
        if len(self.pooling_params) == 0:
            pooling_params = []
        else:
            # Note, for now this assumes that all request in the batch
            # are either sampling or pooling requests
            assert len(self.req_ids) == len(self.pooling_params)
            pooling_params = [
                self.pooling_params[req_id] for req_id in self.req_ids
            ]

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(
                self.num_prompt_tokens[:self.num_reqs]).to(self.device),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
        )

    def _make_prompt_token_ids_tensor(
            self,
            repeat_counts_cpu: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return the batch's prompt token ids as a padded device tensor.

        Rows are padded to the longest prompt with `vocab_size`. When
        `repeat_counts_cpu` is given, row `i` is repeated
        `repeat_counts_cpu[i]` times before the transfer.
        """
        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:self.
                                                 num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(self.num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
        if repeat_counts_cpu is not None:
            prompt_token_ids_cpu_tensor = prompt_token_ids_cpu_tensor \
                .repeat_interleave(repeat_counts_cpu, dim=0)
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)

    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.
        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """

        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping)
        token_lora_mapping = tuple(
            req_lora_mapping.repeat(num_scheduled_tokens))
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values())

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        return len(self.top_k_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        return (len(self.presence_penalties_reqs) == 0
                and len(self.frequency_penalties_reqs) == 0
                and len(self.repetition_penalties_reqs) == 0)

    @property
    def max_num_logprobs(self) -> Optional[int]:
        return max(self.num_logprobs.values()) if self.num_logprobs else None

    @property
    def no_prompt_logprob(self) -> bool:
        return not self.num_prompt_logprobs

    @property
    def no_allowed_token_ids(self) -> bool:
        return len(self.has_allowed_token_ids) == 0