Unverified Commit 3c699772 authored by Liangsheng Yin, committed by GitHub

Introduce naming convention in `io_struct` and base sglang io classes. (#10133)

parent e8100774
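
This commit does three related things, all visible in the hunks below: it renames the batched output structs from `*Out` to `*Output` (`BatchTokenIDOut` → `BatchTokenIDOutput`, `BatchStrOut` → `BatchStrOutput`, `BatchEmbeddingOut` → `BatchEmbeddingOutput`, `BatchMultimodalOut` → `BatchMultimodalOutput`, `LoRAUpdateResult` → `LoRAUpdateOutput`), it renames `MultiTokenizerManager` to `TokenizerWorker`, and it introduces two abstract dataclass bases, `BaseReq` and `BaseBatchReq`, that own the request-id fields. A condensed sketch of the resulting pattern (simplified from the `io_struct.py` hunks below; field lists elided; `kw_only` requires Python 3.10+):

    import uuid
    from abc import ABC
    from dataclasses import dataclass, field
    from typing import List, Optional, Union

    @dataclass
    class BaseReq(ABC):
        # Keyword-only, so subclasses keep their own positional fields.
        rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)

    @dataclass
    class BaseBatchReq(ABC):
        rids: Optional[List[str]] = field(default=None, kw_only=True)

    @dataclass
    class BatchTokenIDOutput(BaseBatchReq):  # formerly BatchTokenIDOut
        finished_reasons: List[dict]         # (many more fields in the real class)

    # Positional args still bind to the subclass fields; ids go by keyword.
    out = BatchTokenIDOutput([{"type": "stop"}], rids=[uuid.uuid4().hex])
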
......@@ -22,8 +22,8 @@ import zmq.asyncio
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchTokenIDOut,
BatchEmbeddingOutput,
BatchTokenIDOutput,
HealthCheckOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
......@@ -467,9 +467,9 @@ class GrpcRequestManager:
await self.is_pause_cond.wait()
# Handle different output types
if isinstance(recv_obj, BatchTokenIDOut):
if isinstance(recv_obj, BatchTokenIDOutput):
await self._handle_batch_output(recv_obj)
elif isinstance(recv_obj, BatchEmbeddingOut):
elif isinstance(recv_obj, BatchEmbeddingOutput):
await self._handle_embedding_output(recv_obj)
elif isinstance(recv_obj, HealthCheckOutput):
await self._handle_health_check_output(recv_obj)
......@@ -498,7 +498,7 @@ class GrpcRequestManager:
def _convert_logprob_style(
self,
state: GrpcReqState,
batch_out: BatchTokenIDOut,
batch_out: BatchTokenIDOutput,
batch_index: int,
):
"""
......@@ -545,7 +545,7 @@ class GrpcRequestManager:
batch_out.output_top_logprobs_idx[batch_index]
)
async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
"""Handle batch generation output from scheduler."""
# Process each request in the batch
for i, rid in enumerate(batch_out.rids):
......@@ -666,7 +666,7 @@ class GrpcRequestManager:
asyncio.create_task(cleanup())
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
"""Handle batch embedding output from scheduler."""
for i, rid in enumerate(batch_out.rids):
if rid not in self.rid_to_state:
......
......@@ -94,8 +94,8 @@ from sglang.srt.managers.io_struct import (
VertexGenerateReqInput,
)
from sglang.srt.managers.multi_tokenizer_mixin import (
MultiTokenizerManager,
MultiTokenizerRouter,
TokenizerWorker,
get_main_process_id,
monkey_patch_uvicorn_multiprocessing,
read_from_shared_memory,
......@@ -127,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
# Store global states
@dataclasses.dataclass
class _GlobalState:
tokenizer_manager: Union[
TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
]
tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]
template_manager: TemplateManager
scheduler_info: Dict
......@@ -164,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs:
)
# Launch multi-tokenizer manager process
tokenizer_manager = MultiTokenizerManager(server_args, port_args)
tokenizer_manager = TokenizerWorker(server_args, port_args)
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
......
......@@ -35,7 +35,7 @@ from sglang.srt.lora.utils import (
get_normalized_target_modules,
get_target_module_name,
)
from sglang.srt.managers.io_struct import LoRAUpdateResult
from sglang.srt.managers.io_struct import LoRAUpdateOutput
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import replace_submodule
......@@ -107,8 +107,8 @@ class LoRAManager:
def create_lora_update_result(
self, success: bool, error_message: str = ""
) -> LoRAUpdateResult:
return LoRAUpdateResult(
) -> LoRAUpdateOutput:
return LoRAUpdateOutput(
success=success,
error_message=error_message,
loaded_adapters={
......@@ -117,7 +117,7 @@ class LoRAManager:
},
)
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
"""
Load a single LoRA adapter from the specified path.
......@@ -174,7 +174,7 @@ class LoRAManager:
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
)
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
delete the corresponding LoRA modules.
......
......@@ -26,11 +26,11 @@ import zmq
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchEmbeddingOutput,
BatchMultimodalDecodeReq,
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
BatchMultimodalOutput,
BatchStrOutput,
BatchTokenIDOutput,
FreezeGCReq,
MultiTokenizerRegisterReq,
)
......@@ -101,8 +101,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
self._request_dispatcher = TypeBasedDispatcher(
[
(BatchEmbeddingOut, self.handle_batch_embedding_out),
(BatchTokenIDOut, self.handle_batch_token_id_out),
(BatchEmbeddingOutput, self.handle_batch_embedding_out),
(BatchTokenIDOutput, self.handle_batch_token_id_out),
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
(MultiTokenizerRegisterReq, lambda x: x),
(FreezeGCReq, self.handle_freeze_gc_req),
......@@ -145,11 +145,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
return output[:-1]
return output
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput):
# If it is embedding model, no detokenization is needed.
return recv_obj
def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
bs = len(recv_obj.rids)
# Initialize decode status
......@@ -224,7 +224,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
s.sent_offset = len(output_str)
output_strs.append(incremental_output)
return BatchStrOut(
return BatchStrOutput(
rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
......@@ -252,7 +252,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
outputs = self.tokenizer.detokenize(recv_obj)
return BatchMultimodalOut(
return BatchMultimodalOutput(
rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons,
outputs=outputs,
......
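
The dispatcher table above simply re-registers the renamed output classes. For readers unfamiliar with it, the following is a minimal stand-in for `TypeBasedDispatcher`, with its behavior inferred from how it is used here (the real class ships elsewhere in sglang; this sketch is illustrative only):

    from typing import Any, Callable, List, Tuple, Type

    class TypeBasedDispatcher:
        """Route an object to the handler registered for its type."""

        def __init__(self, mapping: List[Tuple[Type, Callable]]):
            self._mapping = mapping

        def __call__(self, obj: Any) -> Any:
            for ty, fn in self._mapping:
                if isinstance(obj, ty):
                    return fn(obj)
            raise ValueError(f"Invalid object: {obj}")
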
......@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
import copy
import uuid
from abc import ABC
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
......@@ -36,10 +37,32 @@ else:
# Parameters for a session
@dataclass
class BaseReq(ABC):
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
def regenerate_rid(self):
"""Generate a new request ID and return it."""
if isinstance(self.rid, list):
self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
else:
self.rid = uuid.uuid4().hex
return self.rid
@dataclass
class BaseBatchReq(ABC):
rids: Optional[List[str]] = field(default=None, kw_only=True)
def regenerate_rids(self):
"""Generate new request IDs and return them."""
self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
return self.rids
@dataclass
class SessionParams:
id: Optional[str] = None
rid: Optional[str] = None
offset: Optional[int] = None
replace: Optional[bool] = None
drop_previous_output: Optional[bool] = None
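
Because `rid`/`rids` are declared `kw_only`, subclasses keep their own positional constructor parameters while request IDs must be passed by name; `regenerate_rid` also transparently handles the batched (list) case. A small usage sketch against the bases defined above, with a hypothetical `EchoReq` subclass:

    # EchoReq is hypothetical, for illustration only.
    @dataclass
    class EchoReq(BaseReq):
        text: str = ""

    req = EchoReq(text="hi", rid="req-1")  # rid is keyword-only
    req.regenerate_rid()                   # replaces rid with a fresh uuid4 hex
    assert req.rid != "req-1"
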
......@@ -63,7 +86,7 @@ MultimodalDataInputFormat = Union[
@dataclass
class GenerateReqInput:
class GenerateReqInput(BaseReq):
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
# The token ids for text; one can specify either text or input_ids
......@@ -83,8 +106,6 @@ class GenerateReqInput:
audio_data: Optional[MultimodalDataInputFormat] = None
# The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
# Whether to return logprobs.
return_logprob: Optional[Union[List[bool], bool]] = None
# If return logprobs, the start location in the prompt for returning logprobs.
......@@ -491,11 +512,6 @@ class GenerateReqInput:
):
raise ValueError("Session params must be a dict or a list of dicts.")
def regenerate_rid(self):
"""Generate a new request ID and return it."""
self.rid = uuid.uuid4().hex
return self.rid
def __getitem__(self, i):
return GenerateReqInput(
text=self.text[i] if self.text is not None else None,
......@@ -558,9 +574,7 @@ class GenerateReqInput:
@dataclass
class TokenizedGenerateReqInput:
# The request id
rid: str
class TokenizedGenerateReqInput(BaseReq):
# The input text
input_text: str
# The input token ids
......@@ -625,7 +639,7 @@ class TokenizedGenerateReqInput:
@dataclass
class BatchTokenizedGenerateReqInput:
class BatchTokenizedGenerateReqInput(BaseBatchReq):
# The batch of tokenized requests
batch: List[TokenizedGenerateReqInput]
......@@ -640,7 +654,7 @@ class BatchTokenizedGenerateReqInput:
@dataclass
class EmbeddingReqInput:
class EmbeddingReqInput(BaseReq):
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[List[str]], List[str], str]] = None
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
......@@ -656,8 +670,6 @@ class EmbeddingReqInput:
audio_data: Optional[MultimodalDataInputFormat] = None
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Optional[Union[List[Dict], Dict]] = None
# Dummy input embeds for compatibility
......@@ -728,10 +740,6 @@ class EmbeddingReqInput:
for i in range(self.batch_size):
self.sampling_params[i]["max_new_tokens"] = 0
def regenerate_rid(self):
self.rid = uuid.uuid4().hex
return self.rid
def contains_mm_input(self) -> bool:
return (
has_valid_data(self.image_data)
......@@ -760,9 +768,7 @@ class EmbeddingReqInput:
@dataclass
class TokenizedEmbeddingReqInput:
# The request id
rid: str
class TokenizedEmbeddingReqInput(BaseReq):
# The input text
input_text: str
# The input token ids
......@@ -780,7 +786,7 @@ class TokenizedEmbeddingReqInput:
@dataclass
class BatchTokenizedEmbeddingReqInput:
class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
# The batch of tokenized embedding requests
batch: List[TokenizedEmbeddingReqInput]
......@@ -795,9 +801,7 @@ class BatchTokenizedEmbeddingReqInput:
@dataclass
class BatchTokenIDOut:
# The request id
rids: List[str]
class BatchTokenIDOutput(BaseBatchReq):
# The finish reason
finished_reasons: List[BaseFinishReason]
# For incremental decoding
......@@ -842,7 +846,7 @@ class BatchTokenIDOut:
@dataclass
class BatchMultimodalDecodeReq:
class BatchMultimodalDecodeReq(BaseBatchReq):
decoded_ids: List[int]
input_token_logprobs_val: List[float]
input_token_logprobs_idx: List[int]
......@@ -854,8 +858,6 @@ class BatchMultimodalDecodeReq:
image_resolutions: List[List[int]]
resize_image_resolutions: List[List[int]]
# The request id
rids: List[str]
finished_reasons: List[BaseFinishReason]
# Token counts
......@@ -871,9 +873,7 @@ class BatchMultimodalDecodeReq:
@dataclass
class BatchStrOut:
# The request id
rids: List[str]
class BatchStrOutput(BaseBatchReq):
# The finish reason
finished_reasons: List[dict]
# The output decoded strings
......@@ -909,9 +909,7 @@ class BatchStrOut:
@dataclass
class BatchMultimodalOut:
# The request id
rids: List[str]
class BatchMultimodalOutput(BaseBatchReq):
# The finish reason
finished_reasons: List[dict]
decoded_ids: List[List[int]]
......@@ -936,9 +934,7 @@ class BatchMultimodalOut:
@dataclass
class BatchEmbeddingOut:
# The request id
rids: List[str]
class BatchEmbeddingOutput(BaseBatchReq):
# The finish reason
finished_reasons: List[BaseFinishReason]
# The output embedding
......@@ -952,27 +948,27 @@ class BatchEmbeddingOut:
@dataclass
class ClearHiCacheReqInput:
class ClearHiCacheReqInput(BaseReq):
pass
@dataclass
class ClearHiCacheReqOutput:
class ClearHiCacheReqOutput(BaseReq):
success: bool
@dataclass
class FlushCacheReqInput:
class FlushCacheReqInput(BaseReq):
pass
@dataclass
class FlushCacheReqOutput:
class FlushCacheReqOutput(BaseReq):
success: bool
@dataclass
class UpdateWeightFromDiskReqInput:
class UpdateWeightFromDiskReqInput(BaseReq):
# The model path with the new weights
model_path: str
# The format to load the weights
......@@ -990,7 +986,7 @@ class UpdateWeightFromDiskReqInput:
@dataclass
class UpdateWeightFromDiskReqOutput:
class UpdateWeightFromDiskReqOutput(BaseReq):
success: bool
message: str
# Number of paused requests during weight sync.
......@@ -998,7 +994,7 @@ class UpdateWeightFromDiskReqOutput:
@dataclass
class UpdateWeightsFromDistributedReqInput:
class UpdateWeightsFromDistributedReqInput(BaseReq):
names: List[str]
dtypes: List[str]
shapes: List[List[int]]
......@@ -1013,13 +1009,13 @@ class UpdateWeightsFromDistributedReqInput:
@dataclass
class UpdateWeightsFromDistributedReqOutput:
class UpdateWeightsFromDistributedReqOutput(BaseReq):
success: bool
message: str
@dataclass
class UpdateWeightsFromTensorReqInput:
class UpdateWeightsFromTensorReqInput(BaseReq):
"""Update model weights from tensor input.
- Tensors are serialized for transmission
......@@ -1038,13 +1034,13 @@ class UpdateWeightsFromTensorReqInput:
@dataclass
class UpdateWeightsFromTensorReqOutput:
class UpdateWeightsFromTensorReqOutput(BaseReq):
success: bool
message: str
@dataclass
class InitWeightsSendGroupForRemoteInstanceReqInput:
class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
# The master address
master_address: str
# The ports for each rank's communication group
......@@ -1060,13 +1056,13 @@ class InitWeightsSendGroupForRemoteInstanceReqInput:
@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput:
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
success: bool
message: str
@dataclass
class SendWeightsToRemoteInstanceReqInput:
class SendWeightsToRemoteInstanceReqInput(BaseReq):
# The master address
master_address: str
# The ports for each rank's communication group
......@@ -1076,13 +1072,13 @@ class SendWeightsToRemoteInstanceReqInput:
@dataclass
class SendWeightsToRemoteInstanceReqOutput:
class SendWeightsToRemoteInstanceReqOutput(BaseReq):
success: bool
message: str
@dataclass
class InitWeightsUpdateGroupReqInput:
class InitWeightsUpdateGroupReqInput(BaseReq):
# The master address
master_address: str
# The master port
......@@ -1098,24 +1094,24 @@ class InitWeightsUpdateGroupReqInput:
@dataclass
class InitWeightsUpdateGroupReqOutput:
class InitWeightsUpdateGroupReqOutput(BaseReq):
success: bool
message: str
@dataclass
class DestroyWeightsUpdateGroupReqInput:
class DestroyWeightsUpdateGroupReqInput(BaseReq):
group_name: str = "weight_update_group"
@dataclass
class DestroyWeightsUpdateGroupReqOutput:
class DestroyWeightsUpdateGroupReqOutput(BaseReq):
success: bool
message: str
@dataclass
class UpdateWeightVersionReqInput:
class UpdateWeightVersionReqInput(BaseReq):
# The new weight version
new_version: str
# Whether to abort all running requests before updating
......@@ -1123,89 +1119,87 @@ class UpdateWeightVersionReqInput:
@dataclass
class GetWeightsByNameReqInput:
class GetWeightsByNameReqInput(BaseReq):
name: str
truncate_size: int = 100
@dataclass
class GetWeightsByNameReqOutput:
class GetWeightsByNameReqOutput(BaseReq):
parameter: list
@dataclass
class ReleaseMemoryOccupationReqInput:
class ReleaseMemoryOccupationReqInput(BaseReq):
# Optional tags to identify the memory region, which is primarily used for RL
# Currently we only support `weights` and `kv_cache`
tags: Optional[List[str]] = None
@dataclass
class ReleaseMemoryOccupationReqOutput:
class ReleaseMemoryOccupationReqOutput(BaseReq):
pass
@dataclass
class ResumeMemoryOccupationReqInput:
class ResumeMemoryOccupationReqInput(BaseReq):
# Optional tags to identify the memory region, which is primarily used for RL
# Currently we only support `weights` and `kv_cache`
tags: Optional[List[str]] = None
@dataclass
class ResumeMemoryOccupationReqOutput:
class ResumeMemoryOccupationReqOutput(BaseReq):
pass
@dataclass
class SlowDownReqInput:
class SlowDownReqInput(BaseReq):
forward_sleep_time: Optional[float]
@dataclass
class SlowDownReqOutput:
class SlowDownReqOutput(BaseReq):
pass
@dataclass
class AbortReq:
# The request id
rid: str = ""
class AbortReq(BaseReq):
# Whether to abort all requests
abort_all: bool = False
# The finished reason data
finished_reason: Optional[Dict[str, Any]] = None
abort_reason: Optional[str] = None
    # used in MultiTokenizerManager mode
rids: Optional[Union[List[str], str]] = None
def __post_init__(self):
self.rids = self.rid
        # FIXME: This is a hack to keep behavior the same as the old code
if self.rid is None:
self.rid = ""
@dataclass
class GetInternalStateReq:
class GetInternalStateReq(BaseReq):
pass
@dataclass
class GetInternalStateReqOutput:
class GetInternalStateReqOutput(BaseReq):
internal_state: Dict[Any, Any]
@dataclass
class SetInternalStateReq:
class SetInternalStateReq(BaseReq):
server_args: Dict[str, Any]
@dataclass
class SetInternalStateReqOutput:
class SetInternalStateReqOutput(BaseReq):
updated: bool
server_args: Dict[str, Any]
@dataclass
class ProfileReqInput:
class ProfileReqInput(BaseReq):
# The output directory
output_dir: Optional[str] = None
# If set, it profile as many as this number of steps.
......@@ -1225,7 +1219,7 @@ class ProfileReqType(Enum):
@dataclass
class ProfileReq:
class ProfileReq(BaseReq):
type: ProfileReqType
output_dir: Optional[str] = None
start_step: Optional[int] = None
......@@ -1238,18 +1232,18 @@ class ProfileReq:
@dataclass
class ProfileReqOutput:
class ProfileReqOutput(BaseReq):
success: bool
message: str
@dataclass
class FreezeGCReq:
class FreezeGCReq(BaseReq):
pass
@dataclass
class ConfigureLoggingReq:
class ConfigureLoggingReq(BaseReq):
log_requests: Optional[bool] = None
log_requests_level: Optional[int] = None
dump_requests_folder: Optional[str] = None
......@@ -1258,35 +1252,39 @@ class ConfigureLoggingReq:
@dataclass
class OpenSessionReqInput:
class OpenSessionReqInput(BaseReq):
capacity_of_str_len: int
session_id: Optional[str] = None
@dataclass
class CloseSessionReqInput:
class CloseSessionReqInput(BaseReq):
session_id: str
@dataclass
class OpenSessionReqOutput:
class OpenSessionReqOutput(BaseReq):
session_id: Optional[str]
success: bool
@dataclass
class HealthCheckOutput:
class HealthCheckOutput(BaseReq):
pass
class ExpertDistributionReq(Enum):
class ExpertDistributionReqType(Enum):
START_RECORD = 1
STOP_RECORD = 2
DUMP_RECORD = 3
@dataclass
class ExpertDistributionReq(BaseReq):
action: ExpertDistributionReqType
@dataclass
class ExpertDistributionReqOutput:
class ExpertDistributionReqOutput(BaseReq):
pass
......@@ -1304,7 +1302,7 @@ class Tool:
@dataclass
class ParseFunctionCallReq:
class ParseFunctionCallReq(BaseReq):
text: str # The text to parse.
tools: List[Tool] = field(
default_factory=list
......@@ -1315,31 +1313,31 @@ class ParseFunctionCallReq:
@dataclass
class SeparateReasoningReqInput:
class SeparateReasoningReqInput(BaseReq):
text: str # The text to parse.
reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1".
@dataclass
class VertexGenerateReqInput:
class VertexGenerateReqInput(BaseReq):
instances: List[dict]
parameters: Optional[dict] = None
@dataclass
class RpcReqInput:
class RpcReqInput(BaseReq):
method: str
parameters: Optional[Dict] = None
@dataclass
class RpcReqOutput:
class RpcReqOutput(BaseReq):
success: bool
message: str
@dataclass
class LoadLoRAAdapterReqInput:
class LoadLoRAAdapterReqInput(BaseReq):
    # The name of the lora module to be newly loaded.
lora_name: str
# The path of loading.
......@@ -1359,7 +1357,7 @@ class LoadLoRAAdapterReqInput:
@dataclass
class UnloadLoRAAdapterReqInput:
class UnloadLoRAAdapterReqInput(BaseReq):
# The name of lora module to unload.
lora_name: str
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
......@@ -1373,23 +1371,23 @@ class UnloadLoRAAdapterReqInput:
@dataclass
class LoRAUpdateResult:
class LoRAUpdateOutput(BaseReq):
success: bool
error_message: Optional[str] = None
loaded_adapters: Optional[Dict[str, LoRARef]] = None
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
@dataclass
class MultiTokenizerRegisterReq:
rids: Optional[Union[List[str], str]] = None
class MultiTokenizerRegisterReq(BaseBatchReq):
ipc_name: Optional[str] = None
@dataclass
class MultiTokenizerWrapper:
# FIXME(lsyin): remove this
worker_id: int
obj: Optional[Any] = None
......@@ -1400,17 +1398,17 @@ class BlockReqType(Enum):
@dataclass
class BlockReqInput:
class BlockReqInput(BaseReq):
type: BlockReqType
@dataclass
class GetLoadReqInput:
class GetLoadReqInput(BaseReq):
pass
@dataclass
class GetLoadReqOutput:
class GetLoadReqOutput(BaseReq):
dp_rank: int
num_reqs: int
num_waiting_reqs: int
......@@ -1418,5 +1416,31 @@ class GetLoadReqOutput:
@dataclass
class WatchLoadUpdateReq:
class WatchLoadUpdateReq(BaseReq):
loads: List[GetLoadReqOutput]
def _check_all_req_types():
"""A helper function to check all request types are defined in this file."""
import inspect
import sys
all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
for class_type in all_classes:
# check its name
name = class_type[0]
is_io_struct = (
name.endswith("Req") or name.endswith("Input") or name.endswith("Output")
)
is_base_req = issubclass(class_type[1], BaseReq) or issubclass(
class_type[1], BaseBatchReq
)
if is_io_struct and not is_base_req:
raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
        if is_base_req and not is_io_struct:
            raise ValueError(
                f"{name} is a subclass of BaseReq but does not follow the naming convention."
            )
_check_all_req_types()
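
`_check_all_req_types()` runs at import time, so the convention is self-enforcing in both directions: a class named like an io struct must subclass one of the bases, and a subclass of the bases must be named like an io struct. A hypothetical violation, for illustration:

    # Adding this to io_struct.py would fail immediately on import:
    # the name ends with "Req" but neither base class is present.
    @dataclass
    class DanglingReq:
        pass
    # -> ValueError: DanglingReq is not a subclass of BaseReq or BaseBatchReq.
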
......@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
"""Mixin class and utils for multi-http-worker mode"""
import asyncio
import logging
import multiprocessing as multiprocessing
......@@ -30,10 +30,10 @@ import zmq.asyncio
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchStrOutput,
BatchTokenIDOutput,
MultiTokenizerRegisterReq,
MultiTokenizerWrapper,
)
......@@ -83,8 +83,8 @@ class SocketMapping:
def _handle_output_by_index(output, i):
"""NOTE: A maintainable method is better here."""
if isinstance(output, BatchTokenIDOut):
new_output = BatchTokenIDOut(
if isinstance(output, BatchTokenIDOutput):
new_output = BatchTokenIDOutput(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
......@@ -198,8 +198,8 @@ def _handle_output_by_index(output, i):
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
)
elif isinstance(output, BatchEmbeddingOut):
new_output = BatchEmbeddingOut(
elif isinstance(output, BatchEmbeddingOutput):
new_output = BatchEmbeddingOutput(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
......@@ -216,8 +216,8 @@ def _handle_output_by_index(output, i):
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
)
elif isinstance(output, BatchStrOut):
new_output = BatchStrOut(
elif isinstance(output, BatchStrOutput):
new_output = BatchStrOutput(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
......@@ -314,8 +314,8 @@ def _handle_output_by_index(output, i):
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
)
elif isinstance(output, BatchMultimodalOut):
new_output = BatchMultimodalOut(
elif isinstance(output, BatchMultimodalOutput):
new_output = BatchMultimodalOutput(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
......@@ -343,7 +343,7 @@ def _handle_output_by_index(output, i):
class MultiHttpWorkerDetokenizerMixin:
"""Mixin class for MultiTokenizerManager and DetokenizerManager"""
"""Mixin class for DetokenizerManager"""
def get_worker_ids_from_req_rids(self, rids):
if isinstance(rids, list):
......@@ -386,7 +386,7 @@ class MultiHttpWorkerDetokenizerMixin:
class MultiTokenizerRouter:
"""A router to receive requests from MultiTokenizerManager"""
"""A router to receive requests from TokenizerWorker"""
def __init__(
self,
......@@ -454,8 +454,8 @@ class MultiTokenizerRouter:
self.socket_mapping.send_output(worker_id, new_recv_obj)
class MultiTokenizerManager(TokenizerManager):
"""Multi Process Tokenizer Manager that tokenizes the text."""
class TokenizerWorker(TokenizerManager):
"""Tokenizer Worker in multi-http-worker mode"""
def __init__(
self,
......
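
`_handle_output_by_index` above fans a batched scheduler output back out to the single HTTP worker that owns request `i` by rebuilding the output struct with one-element lists. A simplified, self-contained sketch of that slicing pattern using a stand-in class:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class _BatchOut:  # stand-in for BatchStrOutput and friends
        output_strs: List[str]
        rids: Optional[List[str]] = field(default=None, kw_only=True)

    def _select_index(output: _BatchOut, i: int) -> _BatchOut:
        # Keep only entry i of each per-request list, as the real helper does.
        return _BatchOut(output_strs=[output.output_strs[i]], rids=[output.rids[i]])

    batch = _BatchOut(output_strs=["a", "b"], rids=["r0", "r1"])
    assert _select_index(batch, 1).rids == ["r1"]
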
......@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
DestroyWeightsUpdateGroupReqInput,
ExpertDistributionReq,
ExpertDistributionReqOutput,
ExpertDistributionReqType,
FlushCacheReqInput,
FlushCacheReqOutput,
FreezeGCReq,
......@@ -1487,12 +1488,12 @@ class Scheduler(
req.priority = -sys.maxsize - 1
elif not self.enable_priority_scheduling and req.priority is not None:
abort_req = AbortReq(
req.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": "Using priority is disabled for this server. Please send a new request without a priority.",
},
rid=req.rid,
)
self.send_to_tokenizer.send_pyobj(abort_req)
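
The argument reshuffling here is a direct consequence of the new base class: `rid` is no longer `AbortReq`'s first positional field but a keyword-only field inherited from `BaseReq`, so every call site must name it. The resulting call shape (fragment; `AbortReq` as defined in `io_struct.py` above):

    from http import HTTPStatus

    abort_req = AbortReq(
        finished_reason={
            "type": "abort",
            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
            "message": "Using priority is disabled for this server.",
        },
        rid="example-rid",  # must be passed by keyword now
    )
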
......@@ -1528,12 +1529,12 @@ class Scheduler(
self.send_to_tokenizer.send_pyobj(
AbortReq(
req_to_abort.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": message,
},
rid=req_to_abort.rid,
)
)
return req_to_abort.rid == recv_req.rid
......@@ -2005,7 +2006,7 @@ class Scheduler(
self.new_token_ratio = new_token_ratio
for req in reqs_to_abort:
self.send_to_tokenizer.send_pyobj(
AbortReq(req.rid, abort_reason=req.to_abort_message)
AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
)
logger.info(
......@@ -2575,7 +2576,7 @@ class Scheduler(
if self.enable_hicache_storage:
# to release prefetch events associated with the request
self.tree_cache.release_aborted_request(req.rid)
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
if self.disaggregation_mode == DisaggregationMode.DECODE:
self.tree_cache.cache_finished_req(req)
......@@ -2687,11 +2688,12 @@ class Scheduler(
return SlowDownReqOutput()
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:
action = recv_req.action
if action == ExpertDistributionReqType.START_RECORD:
get_global_expert_distribution_recorder().start_record()
elif recv_req == ExpertDistributionReq.STOP_RECORD:
elif action == ExpertDistributionReqType.STOP_RECORD:
get_global_expert_distribution_recorder().stop_record()
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
elif action == ExpertDistributionReqType.DUMP_RECORD:
get_global_expert_distribution_recorder().dump_record()
else:
raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
......@@ -2774,7 +2776,8 @@ class IdleSleeper:
def is_health_check_generate_req(recv_req):
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
rid = getattr(recv_req, "rid", None)
return rid is not None and rid.startswith("HEALTH_CHECK")
def is_work_request(recv_req):
......
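
The `is_health_check_generate_req` rewrite above is also fallout from the base class: `rid` now defaults to `None` rather than `""`, and `getattr(recv_obj, "rid", "")` returns the stored `None` (the `""` fallback only applies when the attribute is missing entirely), so the old one-liner could raise `AttributeError`. The guarded form, as a fragment with `recv_obj` assumed in scope:

    # Old: getattr(recv_obj, "rid", "").startswith("HEALTH_CHECK")
    # -> AttributeError when recv_obj.rid is None (the new BaseReq default).
    rid = getattr(recv_obj, "rid", None)
    is_health_check = rid is not None and rid.startswith("HEALTH_CHECK")
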
......@@ -9,7 +9,11 @@ import torch
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOutput,
BatchTokenIDOutput,
)
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
if TYPE_CHECKING:
......@@ -140,7 +144,7 @@ class SchedulerOutputProcessorMixin:
logger.error(
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
)
self.abort_request(AbortReq(req.rid))
self.abort_request(AbortReq(rid=req.rid))
req.grammar.finished = req.finished()
else:
# being chunked reqs' prefill is not finished
......@@ -292,7 +296,7 @@ class SchedulerOutputProcessorMixin:
logger.error(
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
)
self.abort_request(AbortReq(req.rid))
self.abort_request(AbortReq(rid=req.rid))
req.grammar.finished = req.finished()
self.set_next_batch_sampling_info_done(batch)
......@@ -714,8 +718,7 @@ class SchedulerOutputProcessorMixin:
return
self.send_to_detokenizer.send_pyobj(
BatchTokenIDOut(
rids,
BatchTokenIDOutput(
finished_reasons,
decoded_texts,
decode_ids_list,
......@@ -741,6 +744,7 @@ class SchedulerOutputProcessorMixin:
output_token_ids_logprobs_val,
output_token_ids_logprobs_idx,
output_hidden_states,
rids=rids,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
)
......@@ -761,12 +765,12 @@ class SchedulerOutputProcessorMixin:
prompt_tokens.append(len(req.origin_input_ids))
cached_tokens.append(req.cached_tokens)
self.send_to_detokenizer.send_pyobj(
BatchEmbeddingOut(
rids,
BatchEmbeddingOutput(
finished_reasons,
embeddings,
prompt_tokens,
cached_tokens,
rids=rids,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
)
......
......@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
DestroyWeightsUpdateGroupReqOutput,
ExpertDistributionReq,
ExpertDistributionReqOutput,
ExpertDistributionReqType,
FlushCacheReqInput,
FlushCacheReqOutput,
GetInternalStateReq,
......@@ -44,7 +45,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
LoRAUpdateOutput,
MultiTokenizerWrapper,
OpenSessionReqInput,
ProfileReq,
......@@ -276,7 +277,7 @@ class TokenizerCommunicatorMixin:
self.expert_distribution_communicator.handle_recv,
),
(
LoRAUpdateResult,
LoRAUpdateOutput,
self.update_lora_adapter_communicator.handle_recv,
),
(
......@@ -335,15 +336,18 @@ class TokenizerCommunicatorMixin:
async def start_expert_distribution_record(self: TokenizerManager):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD)
await self.expert_distribution_communicator(req)
async def stop_expert_distribution_record(self: TokenizerManager):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
req = ExpertDistributionReq(action=ExpertDistributionReqType.STOP_RECORD)
await self.expert_distribution_communicator(req)
async def dump_expert_distribution_record(self: TokenizerManager):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
req = ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD)
await self.expert_distribution_communicator(req)
async def init_weights_update_group(
self: TokenizerManager,
......
......@@ -48,18 +48,17 @@ from sglang.srt.hf_transformers_utils import (
get_tokenizer,
get_tokenizer_from_processor,
)
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.lora.lora_registry import LoRARegistry
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchStrOutput,
BatchTokenIDOutput,
BatchTokenizedEmbeddingReqInput,
BatchTokenizedGenerateReqInput,
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
FreezeGCReq,
......@@ -67,7 +66,6 @@ from sglang.srt.managers.io_struct import (
GetLoadReqInput,
HealthCheckOutput,
MultiTokenizerWrapper,
OpenSessionReqInput,
OpenSessionReqOutput,
SessionParams,
TokenizedEmbeddingReqInput,
......@@ -341,10 +339,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
[
(
(
BatchStrOut,
BatchEmbeddingOut,
BatchTokenIDOut,
BatchMultimodalOut,
BatchStrOutput,
BatchEmbeddingOutput,
BatchTokenIDOutput,
BatchMultimodalOutput,
),
self._handle_batch_output,
),
......@@ -716,7 +714,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
tokenized_obj = TokenizedGenerateReqInput(
obj.rid,
input_text,
input_ids,
mm_inputs,
......@@ -726,6 +723,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
obj.top_logprobs_num,
obj.token_ids_logprob,
obj.stream,
rid=obj.rid,
bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port,
bootstrap_room=obj.bootstrap_room,
......@@ -740,12 +738,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
obj.rid,
input_text,
input_ids,
mm_inputs,
token_type_ids,
sampling_params,
rid=obj.rid,
priority=obj.priority,
)
......@@ -1038,7 +1036,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
def abort_request(self, rid: str = "", abort_all: bool = False):
if not abort_all and rid not in self.rid_to_state:
return
req = AbortReq(rid, abort_all)
req = AbortReq(rid=rid, abort_all=abort_all)
self.send_to_scheduler.send_pyobj(req)
if self.enable_metrics:
# TODO: also use custom_labels from the request
......@@ -1303,7 +1301,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
def _handle_batch_output(
self,
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
BatchStrOutput,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchTokenIDOutput,
],
):
for i, rid in enumerate(recv_obj.rids):
......@@ -1337,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
i,
)
if not isinstance(recv_obj, BatchEmbeddingOut):
if not isinstance(recv_obj, BatchEmbeddingOutput):
meta_info.update(
{
"completion_tokens": recv_obj.completion_tokens[i],
......@@ -1348,7 +1349,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if getattr(recv_obj, "output_hidden_states", None):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
if isinstance(recv_obj, BatchStrOut):
if isinstance(recv_obj, BatchStrOutput):
state.text += recv_obj.output_strs[i]
if state.obj.stream:
state.output_ids.extend(recv_obj.output_ids[i])
......@@ -1363,7 +1364,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"output_ids": output_token_ids,
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchTokenIDOut):
elif isinstance(recv_obj, BatchTokenIDOutput):
if self.server_args.stream_output and state.obj.stream:
state.output_ids.extend(recv_obj.output_ids[i])
output_token_ids = state.output_ids[state.last_output_offset :]
......@@ -1376,10 +1377,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"output_ids": output_token_ids,
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchMultimodalOut):
elif isinstance(recv_obj, BatchMultimodalOutput):
raise NotImplementedError("BatchMultimodalOut not implemented")
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
assert isinstance(recv_obj, BatchEmbeddingOutput)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": meta_info,
......@@ -1418,7 +1419,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
top_logprobs_num: int,
token_ids_logprob: List[int],
return_text_in_logprobs: bool,
recv_obj: BatchStrOut,
recv_obj: BatchStrOutput,
recv_obj_index: int,
):
if recv_obj.input_token_logprobs_val is None:
......@@ -1536,7 +1537,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
ret.append(None)
return ret
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
completion_tokens = (
recv_obj.completion_tokens[i]
if getattr(recv_obj, "completion_tokens", None)
......@@ -1632,7 +1633,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
asyncio.create_task(asyncio.to_thread(background_task))
def _handle_abort_req(self, recv_obj):
def _handle_abort_req(self, recv_obj: AbortReq):
if is_health_check_generate_req(recv_obj):
return
state = self.rid_to_state[recv_obj.rid]
......