Unverified Commit bf72b801 authored by Lianmin Zheng, committed by GitHub

[Auto Sync] Update io_struct.py (20250909) (#10236)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: cctry <shiyang@x.ai>
parent 8cbe1538
...@@ -246,6 +248,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
                output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
                output_hidden_states=recv_obj.output_hidden_states,
                placeholder_tokens_idx=None,
                placeholder_tokens_val=None,
            )

    def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
...@@ -257,6 +259,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                prompt_tokens=recv_obj.prompt_tokens,
                completion_tokens=recv_obj.completion_tokens,
                cached_tokens=recv_obj.cached_tokens,
                placeholder_tokens_idx=None,
                placeholder_tokens_val=None,
            )

    def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
...
...@@ -121,6 +121,7 @@ class GenerateReqInput:
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None
    bootstrap_pair_key: Optional[Union[List[str], str]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None
...@@ -128,6 +129,15 @@ class GenerateReqInput:
    # For background responses (OpenAI responses API)
    background: bool = False

    # Conversation id used for tracking requests
    conversation_id: Optional[str] = None

    # Label for the request
    label: Optional[str] = None

    # Image gen grpc migration
    return_bytes: bool = False

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
...@@ -258,6 +268,7 @@ class GenerateReqInput:
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)
        self._normalize_bootstrap_params(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
...@@ -297,6 +308,11 @@ class GenerateReqInput:
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            # Handle empty list case - treat as no images
            if len(self.image_data) == 0:
                self.image_data = [None] * num
                return
            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
...@@ -421,6 +437,40 @@ class GenerateReqInput:
                    "Cannot use list custom_logit_processor with parallel_sample_num > 1"
                )

    def _normalize_bootstrap_params(self, num):
        """Normalize bootstrap parameters for batch processing."""
        # Normalize bootstrap_host
        if self.bootstrap_host is None:
            self.bootstrap_host = [None] * num
        elif not isinstance(self.bootstrap_host, list):
            self.bootstrap_host = [self.bootstrap_host] * num
        elif isinstance(self.bootstrap_host, list):
            self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num

        # Normalize bootstrap_port
        if self.bootstrap_port is None:
            self.bootstrap_port = [None] * num
        elif not isinstance(self.bootstrap_port, list):
            self.bootstrap_port = [self.bootstrap_port] * num
        elif isinstance(self.bootstrap_port, list):
            self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num

        # Normalize bootstrap_room
        if self.bootstrap_room is None:
            self.bootstrap_room = [None] * num
        elif not isinstance(self.bootstrap_room, list):
            self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
        elif isinstance(self.bootstrap_room, list):
            self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num

        # Normalize bootstrap_pair_key
        if self.bootstrap_pair_key is None:
            self.bootstrap_pair_key = [None] * num
        elif not isinstance(self.bootstrap_pair_key, list):
            self.bootstrap_pair_key = [self.bootstrap_pair_key] * num
        elif isinstance(self.bootstrap_pair_key, list):
            self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num
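A standalone sketch of the broadcasting rules implemented above, with made-up values: scalar hosts, ports, and pair keys are repeated, a scalar `bootstrap_room` is offset per sample so each parallel sample gets a distinct room id, and list inputs are tiled by `parallel_sample_num`:

```python
num = 3                    # batch_size * parallel_sample_num
parallel_sample_num = 3

bootstrap_host = "prefill-0"                          # scalar -> plain broadcast
hosts = [bootstrap_host] * num                        # ['prefill-0', 'prefill-0', 'prefill-0']

bootstrap_room = 100                                  # scalar -> offset per sample
rooms = [bootstrap_room + i for i in range(num)]      # [100, 101, 102]

bootstrap_pair_key = ["k0"]                           # list -> tiled by parallel_sample_num
pair_keys = bootstrap_pair_key * parallel_sample_num  # ['k0', 'k0', 'k0']

assert rooms == [100, 101, 102] and hosts == ["prefill-0"] * 3
```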
    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
...@@ -453,7 +503,13 @@ class GenerateReqInput:
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            modalities=self.modalities[i] if self.modalities else None,
            session_params=self.session_params,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            lora_id=self.lora_id[i] if self.lora_id is not None else None,
            custom_logit_processor=(
...@@ -461,11 +517,6 @@ class GenerateReqInput:
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
...@@ -476,9 +527,17 @@ class GenerateReqInput:
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            bootstrap_pair_key=(
                self.bootstrap_pair_key[i]
                if self.bootstrap_pair_key is not None
                else None
            ),
            data_parallel_rank=(
                self.data_parallel_rank if self.data_parallel_rank is not None else None
            ),
            conversation_id=self.conversation_id,
            label=self.label,
            return_bytes=self.return_bytes,
        )
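A hedged usage sketch of the `__getitem__` path above, with invented field values; it assumes the request is first normalized (so the bootstrap fields become lists, as the inline comment requires) via `normalize_batch_and_arguments`, the batch-normalization entry point that calls the `_normalize_*` helpers shown earlier (name assumed from this module):

```python
req = GenerateReqInput(
    text=["a cat", "a dog"],       # batch of two prompts
    bootstrap_host="prefill-0",    # scalar, broadcast during normalization
    bootstrap_room=100,            # scalar, offset per request
)
req.normalize_batch_and_arguments()
single = req[0]                    # single-request GenerateReqInput view
# single.bootstrap_host == "prefill-0", single.bootstrap_room == 100
```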
...@@ -504,27 +563,28 @@ class TokenizedGenerateReqInput:
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # Whether to return hidden states
    return_hidden_states: bool = False

    # LoRA related
    lora_id: Optional[str] = None  # None means just use the base model

    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # LoRA related
    lora_id: Optional[str] = None  # None means just use the base model

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None
    bootstrap_pair_key: Optional[str] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None
...@@ -532,6 +592,12 @@ class TokenizedGenerateReqInput:
    # For dp balance
    dp_balance_id: int = -1

    # Label for the request
    label: Optional[str] = None

    # Image gen grpc migration
    return_bytes: bool = False


@dataclass
class BatchTokenizedGenerateReqInput:
...@@ -738,9 +804,26 @@ class BatchTokenIDOut:
    # Hidden states
    output_hidden_states: List[List[float]]

    # The information of placeholder tokens (e.g., image token)
    # idx is the index of the token in the prompt after expansion.
    # val is the length of padded tokens after expansion.
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]
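An illustrative reading of the two new fields, with invented values: if request 0 carries an image token at prompt position 5 that the processor expanded into 576 padded tokens, and request 1 has no multimodal input, the per-batch lists would look like:

```python
placeholder_tokens_idx = [[5], None]    # request 0: expansion starts at prompt index 5
placeholder_tokens_val = [[576], None]  # request 0: 576 padded tokens inserted there
```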

@dataclass
class BatchMultimodalDecodeReq:
    decoded_ids: List[int]
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    read_offsets: List[int]
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    image_resolutions: List[List[int]]
    resize_image_resolutions: List[List[int]]

    # The request id
    rids: List[str]
    finished_reasons: List[BaseFinishReason]
...@@ -750,6 +833,12 @@ class BatchMultimodalDecodeReq:
    completion_tokens: List[int]
    cached_tokens: List[int]

    # Placeholder token info
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: bool = False


@dataclass
class BatchStrOut:
...@@ -785,6 +874,9 @@ class BatchStrOut:
    # Hidden states
    output_hidden_states: List[List[float]]

    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

@dataclass
class BatchMultimodalOut:
...@@ -792,14 +884,26 @@ class BatchMultimodalOut:
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    decoded_ids: List[List[int]]

    # The outputs
    outputs: List[List[Dict]]
    outputs: Union[List[str | bytes], List[List[Dict]]]

    # probability values for input tokens and output tokens
    input_token_logprobs_val: List[List[float]]
    input_token_logprobs_idx: List[List[int]]
    output_token_logprobs_val: List[List[float]]
    output_token_logprobs_idx: List[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: List[bool]
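With `outputs` widened to a `Union`, consumers must branch on which arm they received; `return_bytes` marks the raw-bytes arm used by the image-gen gRPC path. A minimal, hypothetical consumer sketch (not part of the commit):

```python
from typing import Dict, List, Union

def summarize_outputs(
    outputs: Union[List[Union[str, bytes]], List[List[Dict]]]
) -> List[str]:
    """Normalize both Union arms to short printable summaries."""
    summaries = []
    for item in outputs:
        if isinstance(item, bytes):      # return_bytes=True: raw image bytes
            summaries.append(f"<{len(item)} bytes>")
        elif isinstance(item, str):      # string payload
            summaries.append(item)
        else:                            # List[Dict]: structured multimodal output
            summaries.append(f"<{len(item)} segments>")
    return summaries

print(summarize_outputs([b"\x89PNG...", "done"]))  # ['<7 bytes>', 'done']
```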

@dataclass
class BatchEmbeddingOut:
...@@ -812,6 +916,9 @@ class BatchEmbeddingOut:
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]

    # Placeholder token info
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

@dataclass
...@@ -844,6 +951,12 @@ class UpdateWeightFromDiskReqInput:
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None

    # Whether to update weights asynchronously
    is_async: bool = False

    # Whether to empty torch cache
    torch_empty_cache: bool = False

    # Whether to keep the scheduler paused after weight update
    keep_pause: bool = False
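A hedged request sketch combining the three new flags; the `model_path` field is assumed from the class name and is not shown in this hunk:

```python
# Hypothetical: update weights in the background, release cached GPU memory
# afterwards, and keep the scheduler paused until explicitly resumed.
req = UpdateWeightFromDiskReqInput(
    model_path="/models/my-model",  # assumed field, not shown in this diff
    weight_version="v2",
    is_async=True,
    torch_empty_cache=True,
    keep_pause=True,
)
```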

@dataclass
...@@ -983,6 +1096,7 @@ class AbortReq:
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None
    abort_reason: Optional[str] = None
    # used in MultiTokenizerManager mode
    rids: Optional[Union[List[str], str]] = None
...@@ -1061,6 +1175,7 @@ class ConfigureLoggingReq:
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None
    crash_dump_folder: Optional[str] = None


@dataclass
...
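A small, hypothetical configuration sketch using only fields visible in this hunk plus the new `crash_dump_folder`:

```python
req = ConfigureLoggingReq(
    log_requests_level=2,
    dump_requests_folder="/tmp/sglang_dumps",
    dump_requests_threshold=1000,
    crash_dump_folder="/tmp/sglang_crash",  # new: where crash dumps are written
)
```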
...@@ -195,6 +195,8 @@ def _handle_output_by_index(output, i):
                if output.output_hidden_states
                else None
            ),
            placeholder_tokens_idx=None,
            placeholder_tokens_val=None,
        )
    elif isinstance(output, BatchEmbeddingOut):
        new_output = BatchEmbeddingOut(
...@@ -211,6 +213,8 @@ def _handle_output_by_index(output, i):
            cached_tokens=(
                [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
            ),
            placeholder_tokens_idx=None,
            placeholder_tokens_val=None,
        )
    elif isinstance(output, BatchStrOut):
        new_output = BatchStrOut(
...@@ -307,6 +311,8 @@ def _handle_output_by_index(output, i):
                if output.output_hidden_states
                else None
            ),
            placeholder_tokens_idx=None,
            placeholder_tokens_val=None,
        )
    elif isinstance(output, BatchMultimodalOut):
        new_output = BatchMultimodalOut(
...@@ -328,6 +334,8 @@ def _handle_output_by_index(output, i):
            cached_tokens=(
                [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
            ),
            placeholder_tokens_idx=None,
            placeholder_tokens_val=None,
        )
    else:
        new_output = output
...
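The pattern in `_handle_output_by_index` is uniform across the output types: slice one request's view out of each batch field and reset the placeholder fields to `None`, since per-request placeholder info is not propagated through the split. A self-contained toy of that pattern, using a simplified, invented dataclass:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ToyBatchOut:
    rids: List[str]
    cached_tokens: Optional[List[int]]
    placeholder_tokens_idx: Optional[List[List[int]]]
    placeholder_tokens_val: Optional[List[List[int]]]

def handle_by_index(out: ToyBatchOut, i: int) -> ToyBatchOut:
    """Slice a one-request view out of a batch, mirroring the helper above."""
    return ToyBatchOut(
        rids=[out.rids[i]],
        cached_tokens=[out.cached_tokens[i]] if len(out.cached_tokens) > i else None,
        placeholder_tokens_idx=None,  # not propagated per request, as in the diff
        placeholder_tokens_val=None,
    )

batch = ToyBatchOut(["a", "b"], [3, 7], None, None)
assert handle_by_index(batch, 1).rids == ["b"]
```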
...@@ -700,6 +700,8 @@ class SchedulerOutputProcessorMixin:
                    output_token_ids_logprobs_val,
                    output_token_ids_logprobs_idx,
                    output_hidden_states,
                    placeholder_tokens_idx=None,
                    placeholder_tokens_val=None,
                )
            )
...@@ -719,6 +721,12 @@ class SchedulerOutputProcessorMixin:
                cached_tokens.append(req.cached_tokens)
            self.send_to_detokenizer.send_pyobj(
                BatchEmbeddingOut(
                    rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
                    rids,
                    finished_reasons,
                    embeddings,
                    prompt_tokens,
                    cached_tokens,
                    placeholder_tokens_idx=None,
                    placeholder_tokens_val=None,
                )
            )