"docs/vscode:/vscode.git/clone" did not exist on "8eba1684cbb6768ded0bd5cac580a4cd6a67111c"
Unverified Commit 3c699772, authored by Liangsheng Yin, committed by GitHub

Introduce naming convention in `io_struct` and base sglang io classes. (#10133)
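In short: every cross-process IO struct whose class name ends in `Req`, `Input`, or `Output` now derives from a common base, `BaseReq` (single request, keyword-only `rid`) or `BaseBatchReq` (batched, keyword-only `rids`), and batch output classes are renamed from `*Out` to `*Output` (e.g. `BatchTokenIDOut` becomes `BatchTokenIDOutput`). A minimal sketch of the convention follows; `EchoReqInput` is a hypothetical struct for illustration only, while `BaseReq` mirrors the definition added to `io_struct.py` (Python 3.10+ for `kw_only`):

```python
import uuid
from abc import ABC
from dataclasses import dataclass, field
from typing import List, Optional, Union


@dataclass
class BaseReq(ABC):
    # The request id is keyword-only, so subclasses keep positional compatibility.
    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        if isinstance(self.rid, list):
            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
        else:
            self.rid = uuid.uuid4().hex
        return self.rid


@dataclass
class EchoReqInput(BaseReq):  # name ends in "Input" -> must subclass BaseReq
    text: str = ""


req = EchoReqInput(text="hi", rid="explicit-id")  # rid must be passed by keyword
req.regenerate_rid()  # replaces the id with a fresh uuid4 hex string
```

The keyword-only `rid` is also why call sites below change from positional `AbortReq(req.rid, ...)` to `AbortReq(..., rid=req.rid)`.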

parent e8100774
@@ -22,8 +22,8 @@ import zmq.asyncio
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
     HealthCheckOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -467,9 +467,9 @@ class GrpcRequestManager:
                 await self.is_pause_cond.wait()

             # Handle different output types
-            if isinstance(recv_obj, BatchTokenIDOut):
+            if isinstance(recv_obj, BatchTokenIDOutput):
                 await self._handle_batch_output(recv_obj)
-            elif isinstance(recv_obj, BatchEmbeddingOut):
+            elif isinstance(recv_obj, BatchEmbeddingOutput):
                 await self._handle_embedding_output(recv_obj)
             elif isinstance(recv_obj, HealthCheckOutput):
                 await self._handle_health_check_output(recv_obj)
@@ -498,7 +498,7 @@ class GrpcRequestManager:
     def _convert_logprob_style(
         self,
         state: GrpcReqState,
-        batch_out: BatchTokenIDOut,
+        batch_out: BatchTokenIDOutput,
         batch_index: int,
     ):
         """
@@ -545,7 +545,7 @@ class GrpcRequestManager:
                 batch_out.output_top_logprobs_idx[batch_index]
             )

-    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
+    async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
         """Handle batch generation output from scheduler."""
         # Process each request in the batch
         for i, rid in enumerate(batch_out.rids):
@@ -666,7 +666,7 @@ class GrpcRequestManager:
             asyncio.create_task(cleanup())

-    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
+    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
         """Handle batch embedding output from scheduler."""
         for i, rid in enumerate(batch_out.rids):
             if rid not in self.rid_to_state:
...
@@ -94,8 +94,8 @@ from sglang.srt.managers.io_struct import (
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
-    MultiTokenizerManager,
     MultiTokenizerRouter,
+    TokenizerWorker,
     get_main_process_id,
     monkey_patch_uvicorn_multiprocessing,
     read_from_shared_memory,
@@ -127,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: Union[
-        TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
-    ]
+    tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]
     template_manager: TemplateManager
     scheduler_info: Dict
@@ -164,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs:
     )

     # Launch multi-tokenizer manager process
-    tokenizer_manager = MultiTokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerWorker(server_args, port_args)
     template_manager = TemplateManager()
     template_manager.initialize_templates(
         tokenizer_manager=tokenizer_manager,
...
@@ -35,7 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_target_modules,
     get_target_module_name,
 )
-from sglang.srt.managers.io_struct import LoRAUpdateResult
+from sglang.srt.managers.io_struct import LoRAUpdateOutput
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import replace_submodule
@@ -107,8 +107,8 @@ class LoRAManager:
     def create_lora_update_result(
         self, success: bool, error_message: str = ""
-    ) -> LoRAUpdateResult:
-        return LoRAUpdateResult(
+    ) -> LoRAUpdateOutput:
+        return LoRAUpdateOutput(
             success=success,
             error_message=error_message,
             loaded_adapters={
@@ -117,7 +117,7 @@ class LoRAManager:
             },
         )

-    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
         """
         Load a single LoRA adapter from the specified path.
@@ -174,7 +174,7 @@ class LoRAManager:
                 "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
             )

-    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
...
@@ -26,11 +26,11 @@ import zmq
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
-    BatchEmbeddingOut,
+    BatchEmbeddingOutput,
     BatchMultimodalDecodeReq,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     FreezeGCReq,
     MultiTokenizerRegisterReq,
 )
@@ -101,8 +101,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
         self._request_dispatcher = TypeBasedDispatcher(
             [
-                (BatchEmbeddingOut, self.handle_batch_embedding_out),
-                (BatchTokenIDOut, self.handle_batch_token_id_out),
+                (BatchEmbeddingOutput, self.handle_batch_embedding_out),
+                (BatchTokenIDOutput, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
                 (MultiTokenizerRegisterReq, lambda x: x),
                 (FreezeGCReq, self.handle_freeze_gc_req),
@@ -145,11 +145,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
             return output[:-1]
         return output

-    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
+    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput):
         # If it is embedding model, no detokenization is needed.
         return recv_obj

-    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
+    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
         bs = len(recv_obj.rids)

         # Initialize decode status
@@ -224,7 +224,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                 s.sent_offset = len(output_str)
             output_strs.append(incremental_output)

-        return BatchStrOut(
+        return BatchStrOutput(
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             output_strs=output_strs,
@@ -252,7 +252,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
         outputs = self.tokenizer.detokenize(recv_obj)
-        return BatchMultimodalOut(
+        return BatchMultimodalOutput(
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             outputs=outputs,
...
...@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler). ...@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
import copy import copy
import uuid import uuid
from abc import ABC
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
...@@ -36,10 +37,32 @@ else: ...@@ -36,10 +37,32 @@ else:
# Parameters for a session # Parameters for a session
@dataclass
class BaseReq(ABC):
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
def regenerate_rid(self):
"""Generate a new request ID and return it."""
if isinstance(self.rid, list):
self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
else:
self.rid = uuid.uuid4().hex
return self.rid
@dataclass
class BaseBatchReq(ABC):
rids: Optional[List[str]] = field(default=None, kw_only=True)
def regenerate_rids(self):
"""Generate new request IDs and return them."""
self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
return self.rids
@dataclass @dataclass
class SessionParams: class SessionParams:
id: Optional[str] = None id: Optional[str] = None
rid: Optional[str] = None
offset: Optional[int] = None offset: Optional[int] = None
replace: Optional[bool] = None replace: Optional[bool] = None
drop_previous_output: Optional[bool] = None drop_previous_output: Optional[bool] = None
@@ -63,7 +86,7 @@ MultimodalDataInputFormat = Union[

 @dataclass
-class GenerateReqInput:
+class GenerateReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
     # The token ids for text; one can specify either text or input_ids
@@ -83,8 +106,6 @@ class GenerateReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
     # If return logprobs, the start location in the prompt for returning logprobs.
@@ -491,11 +512,6 @@ class GenerateReqInput:
         ):
             raise ValueError("Session params must be a dict or a list of dicts.")

-    def regenerate_rid(self):
-        """Generate a new request ID and return it."""
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def __getitem__(self, i):
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
@@ -558,9 +574,7 @@ class GenerateReqInput:

 @dataclass
-class TokenizedGenerateReqInput:
-    # The request id
-    rid: str
+class TokenizedGenerateReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids
@@ -625,7 +639,7 @@ class TokenizedGenerateReqInput:

 @dataclass
-class BatchTokenizedGenerateReqInput:
+class BatchTokenizedGenerateReqInput(BaseBatchReq):
     # The batch of tokenized requests
     batch: List[TokenizedGenerateReqInput]
@@ -640,7 +654,7 @@ class BatchTokenizedGenerateReqInput:

 @dataclass
-class EmbeddingReqInput:
+class EmbeddingReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
@@ -656,8 +670,6 @@ class EmbeddingReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # Dummy input embeds for compatibility
@@ -728,10 +740,6 @@ class EmbeddingReqInput:
             for i in range(self.batch_size):
                 self.sampling_params[i]["max_new_tokens"] = 0

-    def regenerate_rid(self):
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -760,9 +768,7 @@ class EmbeddingReqInput:

 @dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
+class TokenizedEmbeddingReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids
@@ -780,7 +786,7 @@ class TokenizedEmbeddingReqInput:

 @dataclass
-class BatchTokenizedEmbeddingReqInput:
+class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
     # The batch of tokenized embedding requests
     batch: List[TokenizedEmbeddingReqInput]
@@ -795,9 +801,7 @@ class BatchTokenizedEmbeddingReqInput:

 @dataclass
-class BatchTokenIDOut:
-    # The request id
-    rids: List[str]
+class BatchTokenIDOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # For incremental decoding
@@ -842,7 +846,7 @@ class BatchTokenIDOut:

 @dataclass
-class BatchMultimodalDecodeReq:
+class BatchMultimodalDecodeReq(BaseBatchReq):
     decoded_ids: List[int]
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -854,8 +858,6 @@ class BatchMultimodalDecodeReq:
     image_resolutions: List[List[int]]
     resize_image_resolutions: List[List[int]]

-    # The request id
-    rids: List[str]
     finished_reasons: List[BaseFinishReason]

     # Token counts
@@ -871,9 +873,7 @@ class BatchMultimodalDecodeReq:

 @dataclass
-class BatchStrOut:
-    # The request id
-    rids: List[str]
+class BatchStrOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[dict]
     # The output decoded strings
@@ -909,9 +909,7 @@ class BatchStrOut:

 @dataclass
-class BatchMultimodalOut:
-    # The request id
-    rids: List[str]
+class BatchMultimodalOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[dict]
     decoded_ids: List[List[int]]
@@ -936,9 +934,7 @@ class BatchMultimodalOut:

 @dataclass
-class BatchEmbeddingOut:
-    # The request id
-    rids: List[str]
+class BatchEmbeddingOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # The output embedding
@@ -952,27 +948,27 @@ class BatchEmbeddingOut:

 @dataclass
-class ClearHiCacheReqInput:
+class ClearHiCacheReqInput(BaseReq):
     pass


 @dataclass
-class ClearHiCacheReqOutput:
+class ClearHiCacheReqOutput(BaseReq):
     success: bool


 @dataclass
-class FlushCacheReqInput:
+class FlushCacheReqInput(BaseReq):
     pass


 @dataclass
-class FlushCacheReqOutput:
+class FlushCacheReqOutput(BaseReq):
     success: bool


 @dataclass
-class UpdateWeightFromDiskReqInput:
+class UpdateWeightFromDiskReqInput(BaseReq):
     # The model path with the new weights
     model_path: str
     # The format to load the weights
@@ -990,7 +986,7 @@ class UpdateWeightFromDiskReqInput:

 @dataclass
-class UpdateWeightFromDiskReqOutput:
+class UpdateWeightFromDiskReqOutput(BaseReq):
     success: bool
     message: str
     # Number of paused requests during weight sync.
@@ -998,7 +994,7 @@ class UpdateWeightFromDiskReqOutput:

 @dataclass
-class UpdateWeightsFromDistributedReqInput:
+class UpdateWeightsFromDistributedReqInput(BaseReq):
     names: List[str]
     dtypes: List[str]
     shapes: List[List[int]]
@@ -1013,13 +1009,13 @@ class UpdateWeightsFromDistributedReqInput:

 @dataclass
-class UpdateWeightsFromDistributedReqOutput:
+class UpdateWeightsFromDistributedReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class UpdateWeightsFromTensorReqInput:
+class UpdateWeightsFromTensorReqInput(BaseReq):
     """Update model weights from tensor input.

     - Tensors are serialized for transmission
@@ -1038,13 +1034,13 @@ class UpdateWeightsFromTensorReqInput:

 @dataclass
-class UpdateWeightsFromTensorReqOutput:
+class UpdateWeightsFromTensorReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqInput:
+class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group
@@ -1060,13 +1056,13 @@ class InitWeightsSendGroupForRemoteInstanceReqInput:

 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqOutput:
+class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class SendWeightsToRemoteInstanceReqInput:
+class SendWeightsToRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group
@@ -1076,13 +1072,13 @@ class SendWeightsToRemoteInstanceReqInput:

 @dataclass
-class SendWeightsToRemoteInstanceReqOutput:
+class SendWeightsToRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class InitWeightsUpdateGroupReqInput:
+class InitWeightsUpdateGroupReqInput(BaseReq):
     # The master address
     master_address: str
     # The master port
@@ -1098,24 +1094,24 @@ class InitWeightsUpdateGroupReqInput:

 @dataclass
-class InitWeightsUpdateGroupReqOutput:
+class InitWeightsUpdateGroupReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class DestroyWeightsUpdateGroupReqInput:
+class DestroyWeightsUpdateGroupReqInput(BaseReq):
     group_name: str = "weight_update_group"


 @dataclass
-class DestroyWeightsUpdateGroupReqOutput:
+class DestroyWeightsUpdateGroupReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class UpdateWeightVersionReqInput:
+class UpdateWeightVersionReqInput(BaseReq):
     # The new weight version
     new_version: str
     # Whether to abort all running requests before updating
@@ -1123,89 +1119,87 @@ class UpdateWeightVersionReqInput:

 @dataclass
-class GetWeightsByNameReqInput:
+class GetWeightsByNameReqInput(BaseReq):
     name: str
     truncate_size: int = 100


 @dataclass
-class GetWeightsByNameReqOutput:
+class GetWeightsByNameReqOutput(BaseReq):
     parameter: list


 @dataclass
-class ReleaseMemoryOccupationReqInput:
+class ReleaseMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None


 @dataclass
-class ReleaseMemoryOccupationReqOutput:
+class ReleaseMemoryOccupationReqOutput(BaseReq):
     pass


 @dataclass
-class ResumeMemoryOccupationReqInput:
+class ResumeMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None


 @dataclass
-class ResumeMemoryOccupationReqOutput:
+class ResumeMemoryOccupationReqOutput(BaseReq):
     pass


 @dataclass
-class SlowDownReqInput:
+class SlowDownReqInput(BaseReq):
     forward_sleep_time: Optional[float]


 @dataclass
-class SlowDownReqOutput:
+class SlowDownReqOutput(BaseReq):
     pass


 @dataclass
-class AbortReq:
-    # The request id
-    rid: str = ""
+class AbortReq(BaseReq):
     # Whether to abort all requests
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
     abort_reason: Optional[str] = None
-    # used in MultiTokenzierManager mode
-    rids: Optional[Union[List[str], str]] = None

     def __post_init__(self):
-        self.rids = self.rid
+        # FIXME: This is a hack to keep the same with the old code
+        if self.rid is None:
+            self.rid = ""


 @dataclass
-class GetInternalStateReq:
+class GetInternalStateReq(BaseReq):
     pass


 @dataclass
-class GetInternalStateReqOutput:
+class GetInternalStateReqOutput(BaseReq):
     internal_state: Dict[Any, Any]


 @dataclass
-class SetInternalStateReq:
+class SetInternalStateReq(BaseReq):
     server_args: Dict[str, Any]


 @dataclass
-class SetInternalStateReqOutput:
+class SetInternalStateReqOutput(BaseReq):
     updated: bool
     server_args: Dict[str, Any]


 @dataclass
-class ProfileReqInput:
+class ProfileReqInput(BaseReq):
     # The output directory
     output_dir: Optional[str] = None
     # If set, it profile as many as this number of steps.
@@ -1225,7 +1219,7 @@ class ProfileReqType(Enum):

 @dataclass
-class ProfileReq:
+class ProfileReq(BaseReq):
     type: ProfileReqType
     output_dir: Optional[str] = None
     start_step: Optional[int] = None
@@ -1238,18 +1232,18 @@ class ProfileReq:

 @dataclass
-class ProfileReqOutput:
+class ProfileReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class FreezeGCReq:
+class FreezeGCReq(BaseReq):
     pass


 @dataclass
-class ConfigureLoggingReq:
+class ConfigureLoggingReq(BaseReq):
     log_requests: Optional[bool] = None
     log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
@@ -1258,35 +1252,39 @@ class ConfigureLoggingReq:

 @dataclass
-class OpenSessionReqInput:
+class OpenSessionReqInput(BaseReq):
     capacity_of_str_len: int
     session_id: Optional[str] = None


 @dataclass
-class CloseSessionReqInput:
+class CloseSessionReqInput(BaseReq):
     session_id: str


 @dataclass
-class OpenSessionReqOutput:
+class OpenSessionReqOutput(BaseReq):
     session_id: Optional[str]
     success: bool


 @dataclass
-class HealthCheckOutput:
+class HealthCheckOutput(BaseReq):
     pass


-class ExpertDistributionReq(Enum):
+class ExpertDistributionReqType(Enum):
     START_RECORD = 1
     STOP_RECORD = 2
     DUMP_RECORD = 3


+@dataclass
+class ExpertDistributionReq(BaseReq):
+    action: ExpertDistributionReqType
+
+
 @dataclass
-class ExpertDistributionReqOutput:
+class ExpertDistributionReqOutput(BaseReq):
     pass
@@ -1304,7 +1302,7 @@ class Tool:

 @dataclass
-class ParseFunctionCallReq:
+class ParseFunctionCallReq(BaseReq):
     text: str  # The text to parse.
     tools: List[Tool] = field(
         default_factory=list
@@ -1315,31 +1313,31 @@ class ParseFunctionCallReq:

 @dataclass
-class SeparateReasoningReqInput:
+class SeparateReasoningReqInput(BaseReq):
     text: str  # The text to parse.
     reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


 @dataclass
-class VertexGenerateReqInput:
+class VertexGenerateReqInput(BaseReq):
     instances: List[dict]
     parameters: Optional[dict] = None


 @dataclass
-class RpcReqInput:
+class RpcReqInput(BaseReq):
     method: str
     parameters: Optional[Dict] = None


 @dataclass
-class RpcReqOutput:
+class RpcReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class LoadLoRAAdapterReqInput:
+class LoadLoRAAdapterReqInput(BaseReq):
     # The name of the lora module to newly loaded.
     lora_name: str
     # The path of loading.
@@ -1359,7 +1357,7 @@ class LoadLoRAAdapterReqInput:

 @dataclass
-class UnloadLoRAAdapterReqInput:
+class UnloadLoRAAdapterReqInput(BaseReq):
     # The name of lora module to unload.
     lora_name: str
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
@@ -1373,23 +1371,23 @@ class UnloadLoRAAdapterReqInput:

 @dataclass
-class LoRAUpdateResult:
+class LoRAUpdateOutput(BaseReq):
     success: bool
     error_message: Optional[str] = None
     loaded_adapters: Optional[Dict[str, LoRARef]] = None


-LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput


 @dataclass
-class MultiTokenizerRegisterReq:
-    rids: Optional[Union[List[str], str]] = None
+class MultiTokenizerRegisterReq(BaseBatchReq):
     ipc_name: Optional[str] = None


 @dataclass
 class MultiTokenizerWrapper:
+    # FIXME(lsyin): remove this
     worker_id: int
     obj: Optional[Any] = None
@@ -1400,17 +1398,17 @@ class BlockReqType(Enum):

 @dataclass
-class BlockReqInput:
+class BlockReqInput(BaseReq):
     type: BlockReqType


 @dataclass
-class GetLoadReqInput:
+class GetLoadReqInput(BaseReq):
     pass


 @dataclass
-class GetLoadReqOutput:
+class GetLoadReqOutput(BaseReq):
     dp_rank: int
     num_reqs: int
     num_waiting_reqs: int
@@ -1418,5 +1416,31 @@ class GetLoadReqOutput:

 @dataclass
-class WatchLoadUpdateReq:
+class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]
+
+
+def _check_all_req_types():
+    """A helper function to check that all request types defined in this file follow the naming convention."""
+    import inspect
+    import sys
+
+    all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
+    for class_type in all_classes:
+        # check its name
+        name = class_type[0]
+        is_io_struct = (
+            name.endswith("Req") or name.endswith("Input") or name.endswith("Output")
+        )
+        is_base_req = issubclass(class_type[1], BaseReq) or issubclass(
+            class_type[1], BaseBatchReq
+        )
+        if is_io_struct and not is_base_req:
+            raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
+        if is_base_req and not is_io_struct:
+            raise ValueError(
+                f"{name} is a subclass of BaseReq but does not follow the naming convention."
+            )
+
+
+_check_all_req_types()
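The `_check_all_req_types()` guard above runs at module import, so a struct that violates the convention fails fast rather than drifting silently. A hedged sketch of the kind of definition it rejects (`FooBarReq` is a hypothetical name, not part of this change):

```python
from dataclasses import dataclass


# Hypothetical violation: the name ends in "Req" but the class subclasses
# neither BaseReq nor BaseBatchReq. If this were defined in io_struct.py,
# importing the module would raise:
#   ValueError: FooBarReq is not a subclass of BaseReq or BaseBatchReq.
@dataclass
class FooBarReq:
    value: int = 0
```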
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
+"""Mixin class and utils for multi-http-worker mode"""
 import asyncio
 import logging
 import multiprocessing as multiprocessing
@@ -30,10 +30,10 @@ import zmq.asyncio
 from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
-    BatchEmbeddingOut,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     MultiTokenizerRegisterReq,
     MultiTokenizerWrapper,
 )
@@ -83,8 +83,8 @@ class SocketMapping:

 def _handle_output_by_index(output, i):
     """NOTE: A maintainable method is better here."""
-    if isinstance(output, BatchTokenIDOut):
-        new_output = BatchTokenIDOut(
+    if isinstance(output, BatchTokenIDOutput):
+        new_output = BatchTokenIDOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -198,8 +198,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchEmbeddingOut):
-        new_output = BatchEmbeddingOut(
+    elif isinstance(output, BatchEmbeddingOutput):
+        new_output = BatchEmbeddingOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -216,8 +216,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchStrOut):
-        new_output = BatchStrOut(
+    elif isinstance(output, BatchStrOutput):
+        new_output = BatchStrOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -314,8 +314,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchMultimodalOut):
-        new_output = BatchMultimodalOut(
+    elif isinstance(output, BatchMultimodalOutput):
+        new_output = BatchMultimodalOutput(
            rids=[output.rids[i]],
            finished_reasons=(
                [output.finished_reasons[i]]
@@ -343,7 +343,7 @@ def _handle_output_by_index(output, i):

 class MultiHttpWorkerDetokenizerMixin:
-    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+    """Mixin class for DetokenizerManager"""

     def get_worker_ids_from_req_rids(self, rids):
         if isinstance(rids, list):
@@ -386,7 +386,7 @@ class MultiHttpWorkerDetokenizerMixin:

 class MultiTokenizerRouter:
-    """A router to receive requests from MultiTokenizerManager"""
+    """A router to receive requests from TokenizerWorker"""

     def __init__(
         self,
@@ -454,8 +454,8 @@ class MultiTokenizerRouter:
         self.socket_mapping.send_output(worker_id, new_recv_obj)


-class MultiTokenizerManager(TokenizerManager):
-    """Multi Process Tokenizer Manager that tokenizes the text."""
+class TokenizerWorker(TokenizerManager):
+    """Tokenizer Worker in multi-http-worker mode"""

     def __init__(
         self,
...
@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     FreezeGCReq,
@@ -1487,12 +1488,12 @@ class Scheduler(
             req.priority = -sys.maxsize - 1
         elif not self.enable_priority_scheduling and req.priority is not None:
             abort_req = AbortReq(
-                req.rid,
                 finished_reason={
                     "type": "abort",
                     "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                     "message": "Using priority is disabled for this server. Please send a new request without a priority.",
                 },
+                rid=req.rid,
             )
             self.send_to_tokenizer.send_pyobj(abort_req)
@@ -1528,12 +1529,12 @@ class Scheduler(
                 self.send_to_tokenizer.send_pyobj(
                     AbortReq(
-                        req_to_abort.rid,
                         finished_reason={
                             "type": "abort",
                             "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                             "message": message,
                         },
+                        rid=req_to_abort.rid,
                     )
                 )
                 return req_to_abort.rid == recv_req.rid
@@ -2005,7 +2006,7 @@ class Scheduler(
         self.new_token_ratio = new_token_ratio

         for req in reqs_to_abort:
             self.send_to_tokenizer.send_pyobj(
-                AbortReq(req.rid, abort_reason=req.to_abort_message)
+                AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
             )

         logger.info(
@@ -2575,7 +2576,7 @@ class Scheduler(
             if self.enable_hicache_storage:
                 # to release prefetch events associated with the request
                 self.tree_cache.release_aborted_request(req.rid)
-            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+            self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
             # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
             if self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.tree_cache.cache_finished_req(req)
@@ -2687,11 +2688,12 @@ class Scheduler(
         return SlowDownReqOutput()

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
-        if recv_req == ExpertDistributionReq.START_RECORD:
+        action = recv_req.action
+        if action == ExpertDistributionReqType.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
-        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+        elif action == ExpertDistributionReqType.STOP_RECORD:
             get_global_expert_distribution_recorder().stop_record()
-        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+        elif action == ExpertDistributionReqType.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
@@ -2774,7 +2776,8 @@ class IdleSleeper:

 def is_health_check_generate_req(recv_req):
-    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+    rid = getattr(recv_req, "rid", None)
+    return rid is not None and rid.startswith("HEALTH_CHECK")


 def is_work_request(recv_req):
...
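One side effect of the new base class is visible in the `is_health_check_generate_req` change above: `rid` now defaults to `None` on `BaseReq`, so the old `getattr(recv_req, "rid", "")` could return `None` for a real request object and crash on `.startswith`. A minimal repro of the rewritten guard (`DummyReq` is a hypothetical stand-in for a `BaseReq` subclass):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyReq:
    rid: Optional[str] = None  # BaseReq-style default


def is_health_check_generate_req(recv_req):
    rid = getattr(recv_req, "rid", None)
    return rid is not None and rid.startswith("HEALTH_CHECK")


assert not is_health_check_generate_req(DummyReq())  # rid=None: safely False
assert is_health_check_generate_req(DummyReq(rid="HEALTH_CHECK_abc123"))
```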
@@ -9,7 +9,11 @@ import torch
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
+)
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch

 if TYPE_CHECKING:
@@ -140,7 +144,7 @@ class SchedulerOutputProcessorMixin:
                     logger.error(
                         f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                     )
-                    self.abort_request(AbortReq(req.rid))
+                    self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
@@ -292,7 +296,7 @@ class SchedulerOutputProcessorMixin:
                     logger.error(
                         f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                     )
-                    self.abort_request(AbortReq(req.rid))
+                    self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()

         self.set_next_batch_sampling_info_done(batch)
@@ -714,8 +718,7 @@ class SchedulerOutputProcessorMixin:
                 return

             self.send_to_detokenizer.send_pyobj(
-                BatchTokenIDOut(
-                    rids,
+                BatchTokenIDOutput(
                     finished_reasons,
                     decoded_texts,
                     decode_ids_list,
@@ -741,6 +744,7 @@ class SchedulerOutputProcessorMixin:
                     output_token_ids_logprobs_val,
                     output_token_ids_logprobs_idx,
                     output_hidden_states,
+                    rids=rids,
                     placeholder_tokens_idx=None,
                     placeholder_tokens_val=None,
                 )
@@ -761,12 +765,12 @@ class SchedulerOutputProcessorMixin:
             prompt_tokens.append(len(req.origin_input_ids))
             cached_tokens.append(req.cached_tokens)
         self.send_to_detokenizer.send_pyobj(
-            BatchEmbeddingOut(
-                rids,
+            BatchEmbeddingOutput(
                 finished_reasons,
                 embeddings,
                 prompt_tokens,
                 cached_tokens,
+                rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
             )
...
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqOutput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     GetInternalStateReq,
@@ -44,7 +45,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqOutput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
+    LoRAUpdateOutput,
     MultiTokenizerWrapper,
     OpenSessionReqInput,
     ProfileReq,
@@ -276,7 +277,7 @@ class TokenizerCommunicatorMixin:
                 self.expert_distribution_communicator.handle_recv,
             ),
             (
-                LoRAUpdateResult,
+                LoRAUpdateOutput,
                 self.update_lora_adapter_communicator.handle_recv,
             ),
             (
@@ -335,15 +336,18 @@ class TokenizerCommunicatorMixin:
     async def start_expert_distribution_record(self: TokenizerManager):
         self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
+        req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD)
+        await self.expert_distribution_communicator(req)

     async def stop_expert_distribution_record(self: TokenizerManager):
         self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
+        req = ExpertDistributionReq(action=ExpertDistributionReqType.STOP_RECORD)
+        await self.expert_distribution_communicator(req)

     async def dump_expert_distribution_record(self: TokenizerManager):
         self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
+        req = ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD)
+        await self.expert_distribution_communicator(req)

     async def init_weights_update_group(
         self: TokenizerManager,
...
@@ -48,18 +48,17 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
     get_tokenizer_from_processor,
 )
-from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.lora.lora_registry import LoRARegistry
 from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
-    CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
     FreezeGCReq,
@@ -67,7 +66,6 @@ from sglang.srt.managers.io_struct import (
     GetLoadReqInput,
     HealthCheckOutput,
     MultiTokenizerWrapper,
-    OpenSessionReqInput,
     OpenSessionReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
@@ -341,10 +339,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             [
                 (
                     (
-                        BatchStrOut,
-                        BatchEmbeddingOut,
-                        BatchTokenIDOut,
-                        BatchMultimodalOut,
+                        BatchStrOutput,
+                        BatchEmbeddingOutput,
+                        BatchTokenIDOutput,
+                        BatchMultimodalOutput,
                     ),
                     self._handle_batch_output,
                 ),
@@ -716,7 +714,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         )

         tokenized_obj = TokenizedGenerateReqInput(
-            obj.rid,
             input_text,
             input_ids,
             mm_inputs,
@@ -726,6 +723,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             obj.top_logprobs_num,
             obj.token_ids_logprob,
             obj.stream,
+            rid=obj.rid,
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
@@ -740,12 +738,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
                 token_type_ids,
                 sampling_params,
+                rid=obj.rid,
                 priority=obj.priority,
             )
@@ -1038,7 +1036,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
-        req = AbortReq(rid, abort_all)
+        req = AbortReq(rid=rid, abort_all=abort_all)
         self.send_to_scheduler.send_pyobj(req)

         if self.enable_metrics:
             # TODO: also use custom_labels from the request
@@ -1303,7 +1301,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
     def _handle_batch_output(
         self,
         recv_obj: Union[
-            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
         ],
     ):
         for i, rid in enumerate(recv_obj.rids):
@@ -1337,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     i,
                 )

-            if not isinstance(recv_obj, BatchEmbeddingOut):
+            if not isinstance(recv_obj, BatchEmbeddingOutput):
                 meta_info.update(
                     {
                         "completion_tokens": recv_obj.completion_tokens[i],
@@ -1348,7 +1349,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if getattr(recv_obj, "output_hidden_states", None):
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

-            if isinstance(recv_obj, BatchStrOut):
+            if isinstance(recv_obj, BatchStrOutput):
                 state.text += recv_obj.output_strs[i]
                 if state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
@@ -1363,7 +1364,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                         "output_ids": output_token_ids,
                         "meta_info": meta_info,
                     }
-            elif isinstance(recv_obj, BatchTokenIDOut):
+            elif isinstance(recv_obj, BatchTokenIDOutput):
                 if self.server_args.stream_output and state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
                     output_token_ids = state.output_ids[state.last_output_offset :]
@@ -1376,10 +1377,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchMultimodalOut):
+            elif isinstance(recv_obj, BatchMultimodalOutput):
                 raise NotImplementedError("BatchMultimodalOut not implemented")
             else:
-                assert isinstance(recv_obj, BatchEmbeddingOut)
+                assert isinstance(recv_obj, BatchEmbeddingOutput)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
@@ -1418,7 +1419,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
-        recv_obj: BatchStrOut,
+        recv_obj: BatchStrOutput,
         recv_obj_index: int,
     ):
         if recv_obj.input_token_logprobs_val is None:
@@ -1536,7 +1537,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 ret.append(None)
         return ret

-    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
             if getattr(recv_obj, "completion_tokens", None)
@@ -1632,7 +1633,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         asyncio.create_task(asyncio.to_thread(background_task))

-    def _handle_abort_req(self, recv_obj):
+    def _handle_abort_req(self, recv_obj: AbortReq):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]