""" The definition of objects transfered between different processes (TokenizerManager, DetokenizerManager, Controller). """ import uuid from dataclasses import dataclass from typing import Dict, List, Optional, Union from sglang.srt.managers.controller.infer_batch import BaseFinishReason from sglang.srt.sampling_params import SamplingParams @dataclass class GenerateReqInput: # The input prompt. It can be a single prompt or a batch of prompts. text: Optional[Union[List[str], str]] = None # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None # The image input. It can be a file name, a url, or base64 encoded string. # See also python/sglang/srt/utils.py:load_image. image_data: Optional[Union[List[str], str]] = None # The sampling_params. sampling_params: Union[List[Dict], Dict] = None # The request id. rid: Optional[Union[List[str], str]] = None # Whether to return logprobs. return_logprob: Optional[Union[List[bool], bool]] = None # The start location of the prompt for return_logprob. logprob_start_len: Optional[Union[List[int], int]] = None # The number of top logprobs to return. top_logprobs_num: Optional[Union[List[int], int]] = None # Whether to detokenize tokens in logprobs. return_text_in_logprobs: bool = False # Whether to stream output. stream: bool = False def post_init(self): if (self.text is None and self.input_ids is None) or ( self.text is not None and self.input_ids is not None ): raise ValueError("Either text or input_ids should be provided.") if isinstance(self.sampling_params, dict) and self.sampling_params.get("n", 1) != 1: is_single = False else: if self.text is not None: is_single = isinstance(self.text, str) else: is_single = isinstance(self.input_ids[0], int) self.is_single = is_single if is_single: if self.sampling_params is None: self.sampling_params = {} if self.rid is None: self.rid = uuid.uuid4().hex if self.return_logprob is None: self.return_logprob = False if self.logprob_start_len is None: self.logprob_start_len = 0 if self.top_logprobs_num is None: self.top_logprobs_num = 0 else: parallel_sample_num = self.sampling_params.get("n", 1) if parallel_sample_num != 1: # parallel sampling +1 represents the original prefill stage num = parallel_sample_num + 1 if isinstance(self.text, List): ## suppot batch operation self.batch_size = len(self.text) num = num * len(self.text) else: self.batch_size = 1 else: ## support select operation num = len(self.text) if self.text is not None else len(self.input_ids) self.batch_size = num if self.image_data is None: self.image_data = [None] * num elif not isinstance(self.image_data, list): self.image_data = [self.image_data] * num if self.sampling_params is None: self.sampling_params = [{}] * num elif not isinstance(self.sampling_params, list): self.sampling_params = [self.sampling_params] * num if self.rid is None: self.rid = [uuid.uuid4().hex for _ in range(num)] else: if not isinstance(self.rid, list): raise ValueError("The rid should be a list.") if self.return_logprob is None: self.return_logprob = [False] * num elif not isinstance(self.return_logprob, list): self.return_logprob = [self.return_logprob] * num if self.logprob_start_len is None: self.logprob_start_len = [0] * num elif not isinstance(self.logprob_start_len, list): self.logprob_start_len = [self.logprob_start_len] * num if self.top_logprobs_num is None: self.top_logprobs_num = [0] * num elif not isinstance(self.top_logprobs_num, list): self.top_logprobs_num = [self.top_logprobs_num] * num 
# The tokenized request sent from TokenizerManager to the Controller.
@dataclass
class TokenizedGenerateReqInput:
    # The request id.
    rid: str
    # The input text.
    input_text: str
    # The input token ids.
    input_ids: List[int]
    # The image pixel values, if any.
    pixel_values: List[float]
    # The hash of the image data.
    image_hash: int
    # The image size.
    image_size: List[int]
    # The sampling parameters.
    sampling_params: SamplingParams
    # Whether to return logprobs.
    return_logprob: bool
    # The start location of the prompt for computing logprobs.
    logprob_start_len: int
    # The number of top logprobs to return.
    top_logprobs_num: int
    # Whether to stream output.
    stream: bool


# The batched token-id output sent from the Controller to DetokenizerManager.
@dataclass
class BatchTokenIDOut:
    rids: List[str]
    vids: List[int]
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    meta_info: List[Dict]
    finished_reason: List[BaseFinishReason]


# The batched detokenized output sent from DetokenizerManager back to
# TokenizerManager.
@dataclass
class BatchStrOut:
    # The request ids.
    rids: List[str]
    # The output decoded strings.
    output_strs: List[str]
    # The meta info.
    meta_info: List[Dict]
    # The finish reason.
    finished_reason: List[BaseFinishReason]


# A request to flush the cache.
@dataclass
class FlushCacheReq:
    pass


# A request to abort a running request.
@dataclass
class AbortReq:
    # The request id to abort.
    rid: str


# A request to detokenize a list of token ids.
@dataclass
class DetokenizeReqInput:
    # The token ids to detokenize.
    input_ids: List[int]
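
# Minimal sketch for manual inspection (editorial addition; run this module
# directly rather than through the server). It walks the single-prompt path
# of GenerateReqInput.post_init() and then builds the downstream
# TokenizedGenerateReqInput by hand. The token ids are placeholders, not real
# tokenizer output, and it is assumed that SamplingParams accepts a
# `temperature` keyword argument.
if __name__ == "__main__":
    req = GenerateReqInput(text="Hello", sampling_params={"temperature": 0.7})
    req.post_init()
    assert req.is_single and isinstance(req.rid, str)

    tokenized = TokenizedGenerateReqInput(
        rid=req.rid,
        input_text=req.text,
        input_ids=[101, 102, 103],  # placeholder token ids
        pixel_values=None,  # no image input
        image_hash=0,
        image_size=None,
        sampling_params=SamplingParams(temperature=0.7),
        return_logprob=req.return_logprob,
        logprob_start_len=req.logprob_start_len,
        top_logprobs_num=req.top_logprobs_num,
        stream=req.stream,
    )
    print(tokenized.rid, tokenized.input_text)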