from __future__ import annotations # Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ Store information about requests and batches. The following is the flow of data structures for a batch: ScheduleBatch -> ModelWorkerBatch -> ForwardBatch - ScheduleBatch is managed by `scheduler.py::Scheduler`. It contains high-level scheduling data. Most of the data is on the CPU. - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`. It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU. It will be transformed from CPU scheduler to GPU model runner. - ForwardBatch is managed by `model_runner.py::ModelRunner`. It contains low-level tensor data. Most of the data consists of GPU tensors. TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future. """ import copy import dataclasses import logging import threading from enum import Enum, auto from http import HTTPStatus from itertools import chain from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union import numpy as np import torch import triton import triton.language as tl from sglang.global_config import global_config from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.disaggregation.base import BaseKVSender from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( ScheduleBatchDisaggregationDecodeMixin, ) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.layers.moe import is_tbo_enabled from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, SWATokenToKVPoolAllocator, ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.metrics.collector import TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs from sglang.srt.utils import flatten_nested_list, support_triton if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 GLOBAL_SERVER_ARGS_KEYS = [ "attention_backend", "mm_attention_backend", "debug_tensor_dump_inject", "debug_tensor_dump_output_folder", "chunked_prefill_size", "device", "disable_chunked_prefix_cache", "disable_flashinfer_cutlass_moe_fp4_allgather", "disable_radix_cache", "enable_dp_lm_head", "enable_flashinfer_allreduce_fusion", "moe_dense_tp_size", "ep_dispatch_algorithm", "ep_num_redundant_experts", "enable_nan_detection", "flashinfer_mla_disable_ragged", "max_micro_batch_size", "disable_shared_experts_fusion", "sampling_backend", "speculative_accept_threshold_single", "speculative_accept_threshold_acc", "torchao_config", "triton_attention_reduce_in_fp32", "num_reserved_decode_tokens", "weight_loader_disable_mmap", "enable_multimodal", "enable_symm_mem", "quantization", "enable_custom_logit_processor", ] # Put some global args for easy access global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS} logger = logging.getLogger(__name__) class BaseFinishReason: def __init__(self, is_error: bool = False): self.is_error = is_error def to_json(self): raise NotImplementedError() class FINISH_MATCHED_TOKEN(BaseFinishReason): def __init__(self, matched: Union[int, List[int]]): super().__init__() self.matched = matched def to_json(self): return { "type": "stop", # to match OpenAI API's return value "matched": self.matched, } class FINISH_MATCHED_STR(BaseFinishReason): def __init__(self, matched: str): super().__init__() self.matched = matched def to_json(self): return { "type": "stop", # to match OpenAI API's return value "matched": self.matched, } class FINISH_LENGTH(BaseFinishReason): def __init__(self, length: int): super().__init__() self.length = length def to_json(self): return { "type": "length", # to match OpenAI API's return value "length": self.length, } class FINISH_ABORT(BaseFinishReason): def __init__(self, message=None, status_code=None, err_type=None): super().__init__(is_error=True) self.message = message or "Aborted" self.status_code = status_code self.err_type = err_type def to_json(self): return { "type": "abort", "message": self.message, "status_code": self.status_code, "err_type": self.err_type, } class Modality(Enum): IMAGE = auto() MULTI_IMAGES = auto() VIDEO = auto() AUDIO = auto() @staticmethod def from_str(modality_str: str): try: return Modality[modality_str.upper()] except KeyError: raise ValueError( f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}" ) @staticmethod def all(): return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO] @dataclasses.dataclass class MultimodalDataItem: """ One MultimodalDataItem contains all inputs for one modality. For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem. One for images and one for audio. We put the common fields first and the model-specific fields in model_specific_data. """ modality: Modality hash: int = None pad_value: int = None offsets: Optional[list] = None # the raw features returned by processor, e.g. pixel_values or audio_features feature: Union[torch.Tensor, np.ndarray] = None # the precomputed embeddings, passed as final encoder embeddings # One and only one of the feature and precomputed_embeddings will be empty precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None # Model-specific data stored in a dictionary model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict) def __getattr__(self, name: str): if ( "model_specific_data" in self.__dict__ and name in self.__dict__["model_specific_data"] ): return self.__dict__["model_specific_data"][name] else: raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{name}'" ) def __setitem__(self, key: str, value: Any): if key in self.__dict__: self.__dict__[key] = value else: self.model_specific_data[key] = value def set(self, key: str, value: Any): self.__setitem__(key, value) @staticmethod def is_empty_list(l): if l is None: return True return len([item for item in flatten_nested_list(l) if item is not None]) == 0 def set_pad_value(self): """ Set the pad value after first hashing the data """ from sglang.srt.managers.mm_utils import hash_feature if self.hash is None: if self.feature is not None: hashed_feature = self.feature else: hashed_feature = self.precomputed_embeddings self.hash = hash_feature(hashed_feature) assert self.hash is not None self.pad_value = self.hash % (1 << 30) def is_modality(self, modality: Modality) -> bool: return self.modality == modality def is_audio(self): return self.modality == Modality.AUDIO def is_image(self): return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES] def is_video(self): return self.modality == Modality.VIDEO def is_valid(self) -> bool: return self.is_image() or self.is_video() or self.is_audio() def validate(self): ... # TODO @staticmethod def from_dict(obj: dict): kwargs = dict(obj) modality = kwargs.pop("modality") if isinstance(modality, str): modality = Modality[modality] ret = MultimodalDataItem(modality=modality, **kwargs) ret.validate() return ret def merge(self, other): self.feature += other.feature self.offsets += other.offsets self.hash = hash((self.hash, other.hash)) self.set_pad_value() @dataclasses.dataclass class MultimodalInputs: """The multimodal data related inputs.""" # items of data mm_items: List[MultimodalDataItem] image_pad_len: Optional[list] = None num_image_tokens: Optional[int] = None # image im_token_id: Optional[int] = None im_start_id: Optional[int] = None im_end_id: Optional[int] = None slice_start_id: Optional[int] = None slice_end_id: Optional[int] = None # video video_token_id: Optional[int] = None # audio audio_token_id: Optional[int] = None audio_start_id: Optional[int] = None audio_end_id: Optional[int] = None # QWen2-VL related mrope_positions: Optional[torch.Tensor] = None mrope_position_delta: Optional[torch.Tensor] = None @staticmethod def from_dict(obj: dict): ret = MultimodalInputs( mm_items=obj["mm_items"], ) assert isinstance(ret.mm_items, list) ret.mm_items = [item for item in ret.mm_items if item.is_valid()] for item in ret.mm_items: item.set_pad_value() optional_args = [ "mrope_positions", "mrope_position_delta", "im_token_id", "im_start_id", "im_end_id", "video_token_id", "slice_start_id", "slice_end_id", "audio_start_id", "audio_end_id", "audio_token_id", ] for arg in optional_args: if arg in obj: setattr(ret, arg, obj[arg]) return ret def contains_image_inputs(self) -> bool: return any(item.is_image() for item in self.mm_items) def contains_video_inputs(self) -> bool: return any(item.is_video() for item in self.mm_items) def contains_audio_inputs(self) -> bool: return any(item.is_audio() for item in self.mm_items) def contains_mm_input(self) -> bool: return any(True for item in self.mm_items if item.is_valid()) def merge(self, other: MultimodalInputs): """ merge image inputs when requests are being merged """ # args needed to be merged optional_args = [ "mm_items", "image_pad_len", ] for arg in optional_args: self_arg = getattr(self, arg, None) if self_arg is not None: setattr(self, arg, self_arg + getattr(other, arg)) mrope_positions = self.mrope_positions if mrope_positions is not None: if other.mrope_positions is None: self.mrope_positions = mrope_positions else: self.mrope_positions = torch.cat( [self.mrope_positions, other.mrope_positions], dim=1 ) mrope_position_delta = self.mrope_position_delta if mrope_position_delta is not None: if other.mrope_position_delta is None: self.mrope_position_delta = mrope_position_delta else: self.mrope_position_delta = torch.cat( [self.mrope_position_delta, other.mrope_position_delta], dim=0 ) for key, val in other.__dict__.items(): if "_id" in key: # set token_ids if getattr(self, key, None) is None: setattr(self, key, getattr(other, key, None)) # other args would be kept intact class Req: """The input and output status of a request.""" def __init__( self, rid: str, origin_input_text: str, origin_input_ids: List[int], sampling_params: SamplingParams, return_logprob: bool = False, top_logprobs_num: int = 0, token_ids_logprob: List[int] = None, stream: bool = False, origin_input_ids_unpadded: Optional[Tuple[int]] = None, lora_id: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, token_type_ids: List[int] = None, session_id: Optional[str] = None, custom_logit_processor: Optional[str] = None, return_hidden_states: bool = False, eos_token_ids: Optional[Set[int]] = None, bootstrap_host: Optional[str] = None, bootstrap_port: Optional[int] = None, bootstrap_room: Optional[int] = None, data_parallel_rank: Optional[int] = None, vocab_size: Optional[int] = None, ): # Input and output info self.rid = rid self.origin_input_text = origin_input_text self.origin_input_ids_unpadded = ( origin_input_ids_unpadded if origin_input_ids_unpadded else origin_input_ids # Before image padding ) self.origin_input_ids = origin_input_ids # Each decode stage's output ids self.output_ids = [] # fill_ids = origin_input_ids + output_ids. Updated if chunked. self.fill_ids = [] self.session_id = session_id self.input_embeds = input_embeds # for corss-endoder model self.token_type_ids = token_type_ids # The length of KV that have been removed in local attention chunked prefill self.evicted_seqlen_local = 0 # Sampling info if isinstance(sampling_params.custom_params, dict): sampling_params = copy.copy(sampling_params) sampling_params.custom_params = sampling_params.custom_params | { "__req__": self } self.sampling_params = sampling_params self.custom_logit_processor = custom_logit_processor self.return_hidden_states = return_hidden_states self.lora_id = lora_id # Memory pool info self.req_pool_idx: Optional[int] = None # Check finish self.tokenizer = None self.finished_reason = None # Whether this request has finished output self.finished_output = None # If we want to abort the request in the middle of the event loop, set this to true # Note: We should never set finished_reason in the middle, the req will get filtered and never respond self.to_abort = False # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop self.to_abort_message: str = None self.stream = stream self.eos_token_ids = eos_token_ids self.vocab_size = vocab_size # For incremental decoding # ----- | --------- read_ids -------| # ----- | surr_ids | # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx | # ----- ^ ----------- ^ ----------- ^ # ----- 1 ----------- 2 ----------- 3 # 1: surr_offset # 2: read_offset # 3: last token self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm self.read_offset = None self.decoded_text = "" # For multimodal inputs self.multimodal_inputs: Optional[MultimodalInputs] = None # Prefix info # The indices to kv cache for the shared prefix. self.prefix_indices: torch.Tensor = [] # Number of tokens to run prefill. self.extend_input_len = 0 # The relative logprob_start_len in an extend batch self.extend_logprob_start_len = 0 self.last_node: Any = None self.last_host_node: Any = None self.host_hit_length = 0 # The node to lock until for swa radix tree lock ref self.swa_uuid_for_lock: Optional[int] = None # Whether or not if it is chunked. It increments whenever # it is chunked, and decrement whenever chunked request is # processed. self.is_chunked = 0 # For retraction self.is_retracted = False # Incremental streamining self.send_token_offset: int = 0 self.send_decode_id_offset: int = 0 # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode # because the decode server does not have the first output token logprobs self.send_output_token_logprobs_offset: int = 0 # Logprobs (arguments) self.return_logprob = return_logprob # Start index to compute logprob from. self.logprob_start_len = 0 self.top_logprobs_num = top_logprobs_num self.token_ids_logprob = token_ids_logprob self.temp_scaled_logprobs = False self.top_p_normalized_logprobs = False # Logprobs (return values) # True means the input logprob has been already sent to detokenizer. self.input_logprob_sent: bool = False self.input_token_logprobs_val: Optional[List[float]] = None self.input_token_logprobs_idx: Optional[List[int]] = None self.input_top_logprobs_val: Optional[List[float]] = None self.input_top_logprobs_idx: Optional[List[int]] = None self.input_token_ids_logprobs_val: Optional[List[float]] = None self.input_token_ids_logprobs_idx: Optional[List[int]] = None # Temporary holder to store input_token_logprobs. self.input_token_logprobs: Optional[List[Tuple[int]]] = None self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None self.temp_input_top_logprobs_idx: Optional[List[int]] = None self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None if return_logprob: # shape: (bs, 1) self.output_token_logprobs_val = [] self.output_token_logprobs_idx = [] # shape: (bs, k) self.output_top_logprobs_val = [] self.output_top_logprobs_idx = [] self.output_token_ids_logprobs_val = [] self.output_token_ids_logprobs_idx = [] else: self.output_token_logprobs_val = self.output_token_logprobs_idx = ( self.output_top_logprobs_val ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = ( self.output_token_ids_logprobs_idx ) = None self.hidden_states: List[List[float]] = [] self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP # Embedding (return values) self.embedding = None # Constrained decoding self.grammar: Optional[BaseGrammarObject] = None self.grammar_wait_ct = 0 # The number of cached tokens that were already cached in the KV cache self.cached_tokens = 0 self.already_computed = 0 # The number of verification forward passes in the speculative decoding. # This is used to compute the average acceptance length per request. self.spec_verify_ct = 0 # For metrics self.time_stats: TimeStats = TimeStats() self.has_log_time_stats: bool = False self.queue_time_start = None self.queue_time_end = None # For disaggregation self.bootstrap_host: str = bootstrap_host self.bootstrap_port: Optional[int] = bootstrap_port self.bootstrap_room: Optional[int] = bootstrap_room self.disagg_kv_sender: Optional[BaseKVSender] = None # For data parallel rank routing self.data_parallel_rank: Optional[int] = data_parallel_rank # the start index of the sent kv cache # We want to send it chunk by chunk for chunked prefill. # After every chunk forward, we do the following: # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)]) # start_send_idx = len(req.fill_ids) self.start_send_idx: int = 0 # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap # This is because kv is not ready in `process_prefill_chunk`. # We use `tmp_end_idx` to store the end index of the kv cache to send. self.tmp_end_idx: int = -1 self.metadata_buffer_index: int = -1 @property def seqlen(self): return len(self.origin_input_ids) + len(self.output_ids) def extend_image_inputs(self, image_inputs): if self.multimodal_inputs is None: self.multimodal_inputs = image_inputs else: self.multimodal_inputs.merge(image_inputs) def finished(self) -> bool: # Whether request reached finished condition return self.finished_reason is not None def init_next_round_input( self, tree_cache: Optional[BasePrefixCache] = None, ): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: if isinstance(tree_cache, LoRARadixCache): ( self.prefix_indices, self.last_node, self.last_host_node, self.host_hit_length, ) = tree_cache.match_prefix_with_lora_id( key=LoRAKey( lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids() ), ) else: ( self.prefix_indices, self.last_node, self.last_host_node, self.host_hit_length, ) = tree_cache.match_prefix( key=self.adjust_max_prefix_ids(), ) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): self.fill_ids = self.origin_input_ids + self.output_ids input_len = len(self.fill_ids) # FIXME: To work around some bugs in logprob computation, we need to ensure each # request has at least one token. Later, we can relax this requirement and use `input_len`. max_prefix_len = input_len - 1 if self.sampling_params.max_new_tokens > 0: # Need at least one token to compute logits max_prefix_len = min(max_prefix_len, input_len - 1) if self.return_logprob: max_prefix_len = min(max_prefix_len, self.logprob_start_len) max_prefix_len = max(max_prefix_len, 0) return self.fill_ids[:max_prefix_len] # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313 def init_incremental_detokenize(self): first_iter = self.surr_offset is None or self.read_offset is None if first_iter: self.read_offset = len(self.origin_input_ids_unpadded) self.surr_offset = max( self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0 ) all_ids = self.origin_input_ids_unpadded + self.output_ids return all_ids[self.surr_offset :], self.read_offset - self.surr_offset def check_finished(self): if self.finished(): return if self.to_abort: self.finished_reason = FINISH_ABORT( message=self.to_abort_message, ) return if len(self.output_ids) >= self.sampling_params.max_new_tokens: self.finished_reason = FINISH_LENGTH( length=self.sampling_params.max_new_tokens ) return if self.grammar is not None: if self.grammar.is_terminated(): self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1]) return last_token_id = self.output_ids[-1] if not self.sampling_params.ignore_eos: matched_eos = False # Check stop token ids if self.sampling_params.stop_token_ids: matched_eos = last_token_id in self.sampling_params.stop_token_ids if self.eos_token_ids: matched_eos |= last_token_id in self.eos_token_ids if self.tokenizer is not None: matched_eos |= last_token_id == self.tokenizer.eos_token_id if self.tokenizer.additional_stop_token_ids: matched_eos |= ( last_token_id in self.tokenizer.additional_stop_token_ids ) if matched_eos: self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) return if last_token_id > self.vocab_size or last_token_id < 0: if self.sampling_params.stop_token_ids: self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids)) if self.eos_token_ids: self.output_ids[-1] = next(iter(self.eos_token_ids)) self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened") return # Check stop strings if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] ) for stop_str in self.sampling_params.stop_strs: if stop_str in tail_str or stop_str in self.decoded_text: self.finished_reason = FINISH_MATCHED_STR(matched=stop_str) return def reset_for_retract(self): self.prefix_indices = [] self.last_node = None self.swa_uuid_for_lock = None self.extend_input_len = 0 self.is_retracted = True self.input_token_logprobs = None self.temp_input_top_logprobs_val = None self.temp_input_top_logprobs_idx = None self.extend_logprob_start_len = 0 self.is_chunked = 0 self.req_pool_idx = None self.already_computed = 0 def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator): token_indices = req_to_token_pool.req_to_token[ self.req_pool_idx, : self.seqlen - 1 ] self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices) def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator): token_indices = req_to_token_pool.req_to_token[ self.req_pool_idx, : self.seqlen - 1 ] token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices) del self.kv_cache_cpu def log_time_stats(self): # If overlap schedule, we schedule one decode batch ahead so this gets called twice. if self.has_log_time_stats is True: return if self.bootstrap_room is not None: prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})" else: prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})" logger.info(f"{prefix}: {self.time_stats}") self.has_log_time_stats = True def set_finish_with_abort(self, error_msg: str): if get_tensor_model_parallel_rank() == 0: logger.error(f"{error_msg}, {self.rid=}") self.multimodal_inputs = None self.grammar = None self.origin_input_ids = [0] # set it to one token to skip the long prefill self.return_logprob = False self.finished_reason = FINISH_ABORT( error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" ) def __repr__(self): return ( f"Req(rid={self.rid}, " f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, " f"{self.grammar=}, " f"{self.sampling_params=})" ) # Batch id bid = 0 @dataclasses.dataclass class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): """Store all information of a batch on the scheduler.""" # Request, memory pool, and cache reqs: List[Req] req_to_token_pool: ReqToTokenPool = None token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None tree_cache: BasePrefixCache = None is_hybrid: bool = False # Batch configs model_config: ModelConfig = None forward_mode: ForwardMode = None enable_overlap: bool = False # Tell whether the current running batch is full so that we can skip # the check of whether to prefill new requests. # This is an optimization to reduce the overhead of the prefill check. batch_is_full: bool = False # Events launch_done: Optional[threading.Event] = None # For chunked prefill in PP chunked_req: Optional[Req] = None # Sampling info sampling_info: SamplingBatchInfo = None next_batch_sampling_info: SamplingBatchInfo = None # Batched arguments to model runner input_ids: torch.Tensor = None # shape: [b], int64 input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32 token_type_ids: torch.Tensor = None # shape: [b], int64 req_pool_indices: torch.Tensor = None # shape: [b], int64 seq_lens: torch.Tensor = None # shape: [b], int64 # The output locations of the KV cache out_cache_loc: torch.Tensor = None # shape: [b], int64 output_ids: torch.Tensor = None # shape: [b], int64 # For multimodal inputs multimodal_inputs: Optional[List] = None # The sum of all sequence lengths seq_lens_sum: int = None # The original sequence lengths, Qwen-1M related orig_seq_lens: torch.Tensor = None # shape: [b], int32 # For DP attention global_num_tokens: Optional[List[int]] = None global_num_tokens_for_logprob: Optional[List[int]] = None is_extend_in_batch: bool = False can_run_dp_cuda_graph: bool = False tbo_split_seq_index: Optional[int] = None global_forward_mode: Optional[ForwardMode] = None # For processing logprobs return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None token_ids_logprobs: Optional[List[List[int]]] = None # For logits and logprob post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False # For extend and mixed chunekd prefill prefix_lens: List[int] = None extend_lens: List[int] = None extend_num_tokens: Optional[int] = None decoding_reqs: List[Req] = None extend_logprob_start_lens: List[int] = None # It comes empty list if logprob is not required. extend_input_logprob_token_ids: Optional[torch.Tensor] = None # For encoder-decoder architectures encoder_cached: Optional[List[bool]] = None encoder_lens: Optional[torch.Tensor] = None encoder_lens_cpu: Optional[List[int]] = None encoder_out_cache_loc: Optional[torch.Tensor] = None # Stream has_stream: bool = False # Has grammar has_grammar: bool = False # Device device: str = "cuda" # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None # Whether to return hidden states return_hidden_states: bool = False # Whether this batch is prefill-only (no token generation needed) is_prefill_only: bool = False # hicache pointer for synchronizing data loading from CPU to GPU hicache_consumer_index: int = 0 @classmethod def init_new( cls, reqs: List[Req], req_to_token_pool: ReqToTokenPool, token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, tree_cache: BasePrefixCache, model_config: ModelConfig, enable_overlap: bool, spec_algorithm: SpeculativeAlgorithm, chunked_req: Optional[Req] = None, ): return_logprob = any(req.return_logprob for req in reqs) is_hybrid = False if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): assert ( tree_cache is None or isinstance(tree_cache, SWARadixCache) or isinstance(tree_cache, SWAChunkCache) ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator" is_hybrid = True return cls( reqs=reqs, req_to_token_pool=req_to_token_pool, token_to_kv_pool_allocator=token_to_kv_pool_allocator, tree_cache=tree_cache, is_hybrid=is_hybrid, model_config=model_config, enable_overlap=enable_overlap, return_logprob=return_logprob, has_stream=any(req.stream for req in reqs), has_grammar=any(req.grammar for req in reqs), device=req_to_token_pool.device, spec_algorithm=spec_algorithm, return_hidden_states=any(req.return_hidden_states for req in reqs), is_prefill_only=all( req.sampling_params.max_new_tokens == 0 for req in reqs ), chunked_req=chunked_req, ) def batch_size(self): return len(self.reqs) def is_empty(self): return len(self.reqs) == 0 def alloc_req_slots(self, num_reqs: int): req_pool_indices = self.req_to_token_pool.alloc(num_reqs) if req_pool_indices is None: raise RuntimeError( "alloc_req_slots runs out of memory. " "Please set a smaller number for `--max-running-requests`. " f"{self.req_to_token_pool.available_size()=}, " f"{num_reqs=}, " ) return req_pool_indices def alloc_token_slots(self, num_tokens: int, backup_state: bool = False): self._evict_tree_cache_if_needed(num_tokens) if backup_state: state = self.token_to_kv_pool_allocator.backup_state() out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens) if out_cache_loc is None: phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode" error_msg = ( f"{phase_str} out of memory. Try to lower your batch size.\n" f"Try to allocate {num_tokens} tokens.\n" f"{self._available_and_evictable_str()}" ) logger.error(error_msg) if self.tree_cache is not None: self.tree_cache.pretty_print() raise RuntimeError(error_msg) if backup_state: return out_cache_loc, state else: return out_cache_loc def alloc_paged_token_slots_extend( self, prefix_lens: torch.Tensor, seq_lens: torch.Tensor, last_loc: torch.Tensor, extend_num_tokens: int, backup_state: bool = False, ): # Over estimate the number of tokens: assume each request needs a new page. num_tokens = ( extend_num_tokens + len(seq_lens) * self.token_to_kv_pool_allocator.page_size ) self._evict_tree_cache_if_needed(num_tokens) if backup_state: state = self.token_to_kv_pool_allocator.backup_state() out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend( prefix_lens, seq_lens, last_loc, extend_num_tokens ) if out_cache_loc is None: error_msg = ( f"Prefill out of memory. Try to lower your batch size.\n" f"Try to allocate {extend_num_tokens} tokens.\n" f"{self._available_and_evictable_str()}" ) logger.error(error_msg) raise RuntimeError(error_msg) if backup_state: return out_cache_loc, state else: return out_cache_loc def alloc_paged_token_slots_decode( self, seq_lens: torch.Tensor, last_loc: torch.Tensor, backup_state: bool = False, ): # Over estimate the number of tokens: assume each request needs a new page. num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size self._evict_tree_cache_if_needed(num_tokens) if backup_state: state = self.token_to_kv_pool_allocator.backup_state() out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc) if out_cache_loc is None: error_msg = ( f"Decode out of memory. Try to lower your batch size.\n" f"Try to allocate {len(seq_lens)} tokens.\n" f"{self._available_and_evictable_str()}" ) logger.error(error_msg) raise RuntimeError(error_msg) if backup_state: return out_cache_loc, state else: return out_cache_loc def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): self.encoder_lens_cpu = [] self.encoder_cached = [] for req in self.reqs: im = req.multimodal_inputs if im is None or im.num_image_tokens is None: # No image input self.encoder_lens_cpu.append(0) self.encoder_cached.append(True) else: self.encoder_lens_cpu.append(im.num_image_tokens) self.encoder_cached.append( self.forward_mode.is_decode() or len(req.prefix_indices) >= im.num_image_tokens ) self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to( self.device, non_blocking=True ) # Strip encoder infos pt = 0 decoder_out_cache_loc = [] encoder_out_cache_loc = [] for i, req in enumerate(self.reqs): encoder_len = self.encoder_lens_cpu[i] seq_lens[i] -= encoder_len if len(req.prefix_indices) < encoder_len: # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) decoder_out_cache_loc.append( self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] ) self.extend_lens[i] -= encoder_len self.extend_num_tokens -= encoder_len else: decoder_out_cache_loc.append( self.out_cache_loc[pt : pt + req.extend_input_len] ) self.prefix_lens[i] -= encoder_len pt += req.extend_input_len # Reassign self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to( self.device, non_blocking=True ) self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True ) if not decoder_out_cache_loc: self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to( self.device, non_blocking=True ) else: self.out_cache_loc = torch.cat(decoder_out_cache_loc) if not encoder_out_cache_loc: self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to( self.device, non_blocking=True ) else: self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc) assert ( len(self.out_cache_loc) == self.extend_num_tokens ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}" def prepare_for_extend(self): self.forward_mode = ForwardMode.EXTEND # Allocate req slots bs = len(self.reqs) req_pool_indices = self.alloc_req_slots(bs) # Init tensors reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] extend_lens = [r.extend_input_len for r in reqs] token_type_ids = [ r.token_type_ids for r in reqs if r.token_type_ids is not None ] req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to( self.device, non_blocking=True ) input_ids_tensor = torch.tensor( list(chain.from_iterable(input_ids)), dtype=torch.int64 ).to(self.device, non_blocking=True) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True ) orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to( self.device, non_blocking=True ) prefix_lens_tensor = torch.tensor( prefix_lens, dtype=torch.int64, device=self.device ) token_type_ids_tensor = None if len(token_type_ids) > 0: token_type_ids_tensor = torch.tensor( sum(token_type_ids, []), dtype=torch.int64 ).to(self.device, non_blocking=True) extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor # Copy prefix and do some basic check input_embeds = [] extend_input_logprob_token_ids = [] multimodal_inputs = [] for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)): req.req_pool_idx = req_pool_indices[i] assert seq_len - pre_len == req.extend_input_len if pre_len > 0: self.req_to_token_pool.write( (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices ) if isinstance(self.tree_cache, SWAChunkCache): self.tree_cache.evict_swa( req, pre_len, self.model_config.attention_chunk_size ) # If input_embeds are available, store them if req.input_embeds is not None: # If req.input_embeds is already a list, append its content directly input_embeds.extend(req.input_embeds) # Use extend to avoid nesting multimodal_inputs.append(req.multimodal_inputs) req.cached_tokens += pre_len - req.already_computed req.already_computed = seq_len req.is_retracted = False # Compute the relative logprob_start_len in an extend batch if req.logprob_start_len >= pre_len: req.extend_logprob_start_len = min( req.logprob_start_len - pre_len, req.extend_input_len, req.seqlen - 1, ) else: req.extend_logprob_start_len = 0 if self.return_logprob: # Find input logprob token ids. # First, find a global index within origin_input_ids and slide it by 1 # to compute input logprobs. It is because you need the next token # to compute input logprobs. E.g., (chunk size 2) # # input_logprobs = [1, 2, 3, 4] # fill_ids = [1, 2] # extend_input_logprob_token_id = [2, 3] # # Note that it can also overflow. In this case, we pad it with 0. # input_logprobs = [1, 2, 3, 4] # fill_ids = [3, 4] # extend_input_logprob_token_id = [4, 0] global_start_idx, global_end_idx = ( len(req.prefix_indices), len(req.fill_ids), ) # Apply logprob_start_len if global_start_idx < req.logprob_start_len: global_start_idx = req.logprob_start_len logprob_token_ids = req.origin_input_ids[ global_start_idx + 1 : global_end_idx + 1 ] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0. extend_input_logprob_token_ids.extend( [0] * ( req.extend_input_len - req.extend_logprob_start_len - len(logprob_token_ids) ) ) if self.return_logprob: extend_input_logprob_token_ids = torch.tensor( extend_input_logprob_token_ids ) else: extend_input_logprob_token_ids = None # Allocate memory if self.token_to_kv_pool_allocator.page_size == 1: out_cache_loc = self.alloc_token_slots(extend_num_tokens) else: last_loc = get_last_loc( self.req_to_token_pool.req_to_token, req_pool_indices_tensor, prefix_lens_tensor, ) out_cache_loc = self.alloc_paged_token_slots_extend( prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens ) # Set fields self.input_ids = input_ids_tensor self.req_pool_indices = req_pool_indices_tensor self.seq_lens = seq_lens_tensor self.orig_seq_lens = orig_seq_lens_tensor self.out_cache_loc = out_cache_loc self.input_embeds = ( torch.tensor(input_embeds).to(self.device, non_blocking=True) if input_embeds else None ) for mm_input in multimodal_inputs: if mm_input is None: continue for mm_item in mm_input.mm_items: pixel_values = getattr(mm_item, "feature", None) if isinstance(pixel_values, torch.Tensor): mm_item.feature = pixel_values.to(self.device, non_blocking=True) self.multimodal_inputs = multimodal_inputs self.token_type_ids = token_type_ids_tensor self.seq_lens_sum = sum(seq_lens) if self.return_logprob: self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] self.token_ids_logprobs = [r.token_ids_logprob for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] self.extend_num_tokens = extend_num_tokens self.prefix_lens = prefix_lens self.extend_lens = extend_lens self.extend_input_logprob_token_ids = extend_input_logprob_token_ids # Write to req_to_token_pool if support_triton(global_server_args_dict.get("attention_backend")): # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start) write_req_to_token_pool_triton[(bs,)]( self.req_to_token_pool.req_to_token, req_pool_indices_tensor, prefix_lens_tensor, seq_lens_tensor, extend_lens_tensor, out_cache_loc, self.req_to_token_pool.req_to_token.shape[1], ) else: pt = 0 for i in range(bs): self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), out_cache_loc[pt : pt + extend_lens[i]], ) pt += extend_lens[i] if self.model_config.is_encoder_decoder: self.prepare_encoder_info_extend(input_ids, seq_lens) # Build sampling info self.sampling_info = SamplingBatchInfo.from_schedule_batch( self, self.model_config.vocab_size, ) def prepare_for_split_prefill(self): self.prepare_for_extend() # For split prefill, we need to set the forward mode to SPLIT_PREFILL self.forward_mode = ForwardMode.SPLIT_PREFILL def mix_with_running(self, running_batch: "ScheduleBatch"): self.forward_mode = ForwardMode.MIXED running_bs = running_batch.batch_size() for req in running_batch.reqs: req.fill_ids = req.origin_input_ids + req.output_ids req.extend_input_len = 1 input_ids = torch.cat([self.input_ids, running_batch.input_ids]) out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) self.merge_batch(running_batch) self.input_ids = input_ids self.out_cache_loc = out_cache_loc # For overlap scheduler, the output_ids has one step delay delta = 0 if self.enable_overlap else -1 # NOTE: prefix_indices is what has been cached, but we don't cache each decode step self.prefix_lens.extend( [ len(r.origin_input_ids) + len(r.output_ids) + delta for r in running_batch.reqs ] ) self.extend_lens.extend([1] * running_bs) self.extend_num_tokens += running_bs # TODO (lianmin): Revisit this. It should be seq_len - 1 self.extend_logprob_start_lens.extend([0] * running_bs) def new_page_count_next_decode(self): page_size = self.token_to_kv_pool_allocator.page_size if page_size == 1: return len(self.reqs) # In the decoding phase, the length of a request's KV cache should be # the total length of the request minus 1 return ( sum(1 for req in self.reqs if req.seqlen % page_size == 0) if self.enable_overlap else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0) ) def check_decode_mem(self, buf_multiplier=1): num_tokens = ( self.new_page_count_next_decode() * buf_multiplier * self.token_to_kv_pool_allocator.page_size ) self._evict_tree_cache_if_needed(num_tokens) return self._is_available_size_sufficient(num_tokens) def retract_decode(self, server_args: ServerArgs): """Retract the decoding requests when there is not enough memory.""" sorted_indices = list(range(len(self.reqs))) # TODO(lsyin): improve retraction policy for radix cache # For spec decoding, filter_batch API can only filter # requests from the back, so we can only retract from the back. # TODO(sang): Clean up finish path and support better retract # policy. if not server_args.speculative_algorithm: sorted_indices.sort( key=lambda i: ( len(self.reqs[i].output_ids), -len(self.reqs[i].origin_input_ids), ), reverse=True, ) def get_required_tokens(num_reqs: int): headroom_for_spec_decode = 0 if server_args.speculative_algorithm: headroom_for_spec_decode += ( num_reqs * server_args.speculative_eagle_topk * server_args.speculative_num_steps + num_reqs * server_args.speculative_num_draft_tokens ) return ( num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode ) def _get_available_size(): if self.is_hybrid: return min( self.token_to_kv_pool_allocator.full_available_size(), self.token_to_kv_pool_allocator.swa_available_size(), ) else: return self.token_to_kv_pool_allocator.available_size() retracted_reqs = [] seq_lens_cpu = self.seq_lens.cpu().numpy() first_iter = True while ( _get_available_size() < get_required_tokens(len(sorted_indices)) or first_iter ): if len(sorted_indices) == 1: # Corner case: only one request left if self.is_hybrid: full_available_size = ( self.token_to_kv_pool_allocator.full_available_size() ) swa_available_size = ( self.token_to_kv_pool_allocator.swa_available_size() ) assert ( full_available_size > 0 and swa_available_size > 0 ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}" else: assert ( self.token_to_kv_pool_allocator.available_size() > 0 ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}" break first_iter = False idx = sorted_indices.pop() req = self.reqs[idx] retracted_reqs.append(req) if server_args.disaggregation_mode == "decode": req.offload_kv_cache( self.req_to_token_pool, self.token_to_kv_pool_allocator ) if isinstance(self.tree_cache, ChunkCache): # ChunkCache does not have eviction token_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : seq_lens_cpu[idx] ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) else: # TODO: apply more fine-grained retraction last_uncached_pos = ( len(req.prefix_indices) // server_args.page_size ) * server_args.page_size token_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) # release the last node if self.is_hybrid: self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) else: self.tree_cache.dec_lock_ref(req.last_node) # NOTE(lsyin): we should use the newly evictable memory instantly. num_tokens = len(sorted_indices) * global_config.retract_decode_steps self._evict_tree_cache_if_needed(num_tokens) req.reset_for_retract() if len(retracted_reqs) == 0: # Corner case: only one request left raise ValueError( "Failed to retract any request. No space left for only one request." ) self.filter_batch(keep_indices=sorted_indices) # Reqs in batch are filtered total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs) total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) new_estimate_ratio = ( total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) ) / total_max_new_tokens new_estimate_ratio = min(1.0, new_estimate_ratio) return retracted_reqs, new_estimate_ratio def prepare_encoder_info_decode(self): # Reset the encoder cached status self.encoder_cached = [True] * len(self.reqs) def prepare_for_idle(self): self.forward_mode = ForwardMode.IDLE self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device) self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device) self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device) self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) self.seq_lens_sum = 0 self.extend_num_tokens = 0 self.sampling_info = SamplingBatchInfo.from_schedule_batch( self, self.model_config.vocab_size, ) def prepare_for_decode(self): self.forward_mode = ForwardMode.DECODE bs = len(self.reqs) if self.spec_algorithm.is_eagle(): # if spec decoding is used, the decode batch is prepared inside # `forward_batch_speculative_generation` after running draft models. return if self.sampling_info.penalizer_orchestrator.is_required: if self.enable_overlap: # TODO: this can be slow, optimize this. delayed_output_ids = torch.tensor( [ ( req.output_ids[-1] if len(req.output_ids) else req.origin_input_ids[-1] ) for req in self.reqs ], dtype=torch.int64, device=self.device, ) self.sampling_info.penalizer_orchestrator.cumulate_output_tokens( delayed_output_ids ) else: self.sampling_info.penalizer_orchestrator.cumulate_output_tokens( self.output_ids.to(torch.int64) ) # Update fields self.input_ids = self.output_ids self.output_ids = None if self.model_config.is_encoder_decoder: locs = self.encoder_lens + self.seq_lens self.prepare_encoder_info_decode() else: locs = self.seq_lens.clone() if self.enable_overlap: # Do not use in-place operations in the overlap mode self.seq_lens = self.seq_lens + 1 self.orig_seq_lens = self.orig_seq_lens + 1 else: # A faster in-place version self.seq_lens.add_(1) self.orig_seq_lens.add_(1) self.seq_lens_sum += bs # free memory if isinstance(self.tree_cache, SWAChunkCache): for req in self.reqs: self.tree_cache.evict_swa( req, req.seqlen - 1, self.model_config.attention_chunk_size ) # Allocate memory if self.token_to_kv_pool_allocator.page_size == 1: self.out_cache_loc = self.alloc_token_slots(bs) else: last_loc = self.req_to_token_pool.req_to_token[ self.req_pool_indices, self.seq_lens - 2 ] self.out_cache_loc = self.alloc_paged_token_slots_decode( self.seq_lens, last_loc ) self.req_to_token_pool.write( (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32) ) def filter_batch( self, chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None, keep_indices: Optional[List[int]] = None, ): if keep_indices is None: if isinstance(chunked_req_to_exclude, Req): chunked_req_to_exclude = [chunked_req_to_exclude] elif chunked_req_to_exclude is None: chunked_req_to_exclude = [] keep_indices = [ i for i in range(len(self.reqs)) if not self.reqs[i].finished() and self.reqs[i] not in chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: # Filter out all requests self.reqs = [] return if len(keep_indices) == len(self.reqs): # No need to filter return keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to( self.device, non_blocking=True ) if self.model_config.is_encoder_decoder: self.encoder_lens = self.encoder_lens[keep_indices_device] self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices] if self.multimodal_inputs is not None: self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices] self.req_pool_indices = self.req_pool_indices[keep_indices_device] self.seq_lens = self.seq_lens[keep_indices_device] self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None self.seq_lens_sum = self.seq_lens.sum().item() self.output_ids = self.output_ids[keep_indices_device] self.return_logprob = any(req.return_logprob for req in self.reqs) if self.return_logprob: self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices] self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices] else: self.top_logprobs_nums = None self.token_ids_logprobs = None self.has_stream = any(req.stream for req in self.reqs) self.has_grammar = any(req.grammar for req in self.reqs) self.sampling_info.filter_batch(keep_indices, keep_indices_device) if self.spec_info: self.spec_info.filter_batch(keep_indices_device) def merge_batch(self, other: "ScheduleBatch"): # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it # needs to be called with pre-merged Batch.reqs. self.sampling_info.merge_batch(other.sampling_info) # Encoder-decoder infos if self.model_config.is_encoder_decoder: self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens]) self.encoder_lens_cpu.extend(other.encoder_lens_cpu) self.req_pool_indices = torch.cat( [self.req_pool_indices, other.req_pool_indices] ) self.seq_lens = torch.cat([self.seq_lens, other.seq_lens]) self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) self.out_cache_loc = None self.seq_lens_sum += other.seq_lens_sum if self.output_ids is not None: self.output_ids = torch.cat([self.output_ids, other.output_ids]) if self.return_logprob and other.return_logprob: self.top_logprobs_nums.extend(other.top_logprobs_nums) self.token_ids_logprobs.extend(other.token_ids_logprobs) elif self.return_logprob: self.top_logprobs_nums.extend([0] * len(other.reqs)) self.token_ids_logprobs.extend([None] * len(other.reqs)) elif other.return_logprob: self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs self.reqs.extend(other.reqs) if self.multimodal_inputs is not None: self.multimodal_inputs.extend(other.multimodal_inputs) self.return_logprob |= other.return_logprob self.has_stream |= other.has_stream self.has_grammar |= other.has_grammar self.return_hidden_states |= other.return_hidden_states if self.spec_info: self.spec_info.merge_batch(other.spec_info) def get_model_worker_batch( self, seq_lens_cpu_cache: Optional[torch.Tensor] = None ) -> ModelWorkerBatch: if self.forward_mode.is_decode_or_idle(): extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None else: extend_seq_lens = self.extend_lens extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens if self.sampling_info: if self.has_grammar: self.sampling_info.grammars = [req.grammar for req in self.reqs] else: self.sampling_info.grammars = None seq_lens_cpu = ( seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens.cpu() ) global bid bid += 1 return ModelWorkerBatch( bid=bid, forward_mode=self.forward_mode, input_ids=self.input_ids, req_pool_indices=self.req_pool_indices, seq_lens=self.seq_lens, orig_seq_lens=self.orig_seq_lens, out_cache_loc=self.out_cache_loc, seq_lens_cpu=seq_lens_cpu, seq_lens_sum=self.seq_lens_sum, return_logprob=self.return_logprob, top_logprobs_nums=self.top_logprobs_nums, token_ids_logprobs=self.token_ids_logprobs, global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, is_extend_in_batch=self.is_extend_in_batch, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, tbo_split_seq_index=self.tbo_split_seq_index, global_forward_mode=self.global_forward_mode, extend_num_tokens=self.extend_num_tokens, extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, extend_logprob_start_lens=extend_logprob_start_lens, multimodal_inputs=self.multimodal_inputs, encoder_cached=self.encoder_cached, encoder_lens=self.encoder_lens, encoder_lens_cpu=self.encoder_lens_cpu, encoder_out_cache_loc=self.encoder_out_cache_loc, lora_ids=[req.lora_id for req in self.reqs], sampling_info=self.sampling_info, input_embeds=self.input_embeds, token_type_ids=self.token_type_ids, spec_algorithm=self.spec_algorithm, spec_info=self.spec_info, hicache_consumer_index=self.hicache_consumer_index, capture_hidden_mode=( CaptureHiddenMode.FULL if self.return_hidden_states else ( getattr( self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL ) if self.spec_info else CaptureHiddenMode.NULL ) ), extend_input_logprob_token_ids=self.extend_input_logprob_token_ids, launch_done=self.launch_done, ) def copy(self): # Only contain fields that will be used by process_batch_result return ScheduleBatch( reqs=self.reqs, model_config=self.model_config, forward_mode=self.forward_mode, out_cache_loc=self.out_cache_loc, return_logprob=self.return_logprob, decoding_reqs=self.decoding_reqs, spec_algorithm=self.spec_algorithm, global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, is_extend_in_batch=self.is_extend_in_batch, is_prefill_only=self.is_prefill_only, ) def _evict_tree_cache_if_needed(self, num_tokens: int): if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)): return if self.is_hybrid: full_available_size = self.token_to_kv_pool_allocator.full_available_size() swa_available_size = self.token_to_kv_pool_allocator.swa_available_size() if full_available_size < num_tokens or swa_available_size < num_tokens: if self.tree_cache is not None: full_num_tokens = max(0, num_tokens - full_available_size) swa_num_tokens = max(0, num_tokens - swa_available_size) self.tree_cache.evict(full_num_tokens, swa_num_tokens) else: if self.token_to_kv_pool_allocator.available_size() < num_tokens: if self.tree_cache is not None: self.tree_cache.evict(num_tokens) def _is_available_size_sufficient(self, num_tokens: int) -> bool: if self.is_hybrid: return ( self.token_to_kv_pool_allocator.full_available_size() >= num_tokens and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens ) else: return self.token_to_kv_pool_allocator.available_size() >= num_tokens def _available_and_evictable_str(self) -> str: if self.is_hybrid: full_available_size = self.token_to_kv_pool_allocator.full_available_size() swa_available_size = self.token_to_kv_pool_allocator.swa_available_size() full_evictable_size = self.tree_cache.full_evictable_size() swa_evictable_size = self.tree_cache.swa_evictable_size() return ( f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n" f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n" f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n" f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n" ) else: available_size = self.token_to_kv_pool_allocator.available_size() evictable_size = self.tree_cache.evictable_size() return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n" def __str__(self): return ( f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " f"#req={(len(self.reqs))})" ) @dataclasses.dataclass class ModelWorkerBatch: # The batch id bid: int # The forward mode forward_mode: ForwardMode # The input ids input_ids: torch.Tensor # The indices of requests in the req_to_token_pool req_pool_indices: torch.Tensor # The sequence length seq_lens: torch.Tensor # The indices of output tokens in the token_to_kv_pool_allocator out_cache_loc: torch.Tensor # The sequence length tensor on CPU seq_lens_cpu: Optional[torch.Tensor] seq_lens_sum: int # For logprob return_logprob: bool top_logprobs_nums: Optional[List[int]] token_ids_logprobs: Optional[List[List[int]]] # For DP attention global_num_tokens: Optional[List[int]] global_num_tokens_for_logprob: Optional[List[int]] is_extend_in_batch: bool can_run_dp_cuda_graph: bool tbo_split_seq_index: Optional[int] global_forward_mode: Optional[ForwardMode] # For extend extend_num_tokens: Optional[int] extend_seq_lens: Optional[List[int]] extend_prefix_lens: Optional[List[int]] extend_logprob_start_lens: Optional[List[int]] extend_input_logprob_token_ids: Optional[torch.Tensor] # For multimodal multimodal_inputs: Optional[List[MultimodalInputs]] # For encoder-decoder encoder_cached: Optional[List[bool]] encoder_lens: Optional[torch.Tensor] encoder_lens_cpu: Optional[List[int]] encoder_out_cache_loc: Optional[torch.Tensor] # For LoRA lora_ids: Optional[List[str]] # Sampling info sampling_info: SamplingBatchInfo # The original sequence lengths, Qwen-1M related orig_seq_lens: Optional[torch.Tensor] = None # The input Embeds input_embeds: Optional[torch.Tensor] = None # For corss-encoder model token_type_ids: Optional[torch.Tensor] = None # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None hicache_consumer_index: int = 0 # Overlap event launch_done: Optional[threading.Event] = None @triton.jit def write_req_to_token_pool_triton( req_to_token_ptr, # [max_batch, max_context_len] req_pool_indices, pre_lens, seq_lens, extend_lens, out_cache_loc, req_to_token_ptr_stride: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 512 pid = tl.program_id(0) req_pool_index = tl.load(req_pool_indices + pid) pre_len = tl.load(pre_lens + pid) seq_len = tl.load(seq_lens + pid) # NOTE: This can be slow for large bs cumsum_start = tl.cast(0, tl.int64) for i in range(pid): cumsum_start += tl.load(extend_lens + i) num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE) for i in range(num_loop): offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE mask = offset < (seq_len - pre_len) value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask) tl.store( req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset + pre_len, value, mask=mask, ) def get_last_loc( req_to_token: torch.Tensor, req_pool_indices_tensor: torch.Tensor, prefix_lens_tensor: torch.Tensor, ) -> torch.Tensor: if ( global_server_args_dict["attention_backend"] != "ascend" and global_server_args_dict["attention_backend"] != "torch_native" ): impl = get_last_loc_triton else: impl = get_last_loc_torch return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor) def get_last_loc_torch( req_to_token: torch.Tensor, req_pool_indices_tensor: torch.Tensor, prefix_lens_tensor: torch.Tensor, ) -> torch.Tensor: return torch.where( prefix_lens_tensor > 0, req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1], torch.full_like(prefix_lens_tensor, -1), ) @triton.jit def get_last_loc_kernel( req_to_token, req_pool_indices_tensor, prefix_lens_tensor, result, num_tokens, req_to_token_stride, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE mask = offset < num_tokens prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0) req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0) token_mask = prefix_lens > 0 token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1) tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1) tl.store(result + offset, tokens, mask=mask) def get_last_loc_triton( req_to_token: torch.Tensor, req_pool_indices_tensor: torch.Tensor, prefix_lens_tensor: torch.Tensor, ) -> torch.Tensor: BLOCK_SIZE = 256 num_tokens = prefix_lens_tensor.shape[0] result = torch.empty_like(prefix_lens_tensor) grid = (triton.cdiv(num_tokens, BLOCK_SIZE),) get_last_loc_kernel[grid]( req_to_token, req_pool_indices_tensor, prefix_lens_tensor, result, num_tokens, req_to_token.stride(0), BLOCK_SIZE, ) return result