Unverified Commit 260fe755 authored by Zhengke Zhou's avatar Zhengke Zhou Committed by GitHub
Browse files

Simplify multi-tokenizer (#11295)


Signed-off-by: default avatarzhengkezhou1 <madzhou1@gmail.com>
Co-authored-by: default avatarLiangsheng Yin <lsyincs@gmail.com>
parent dbb16bed
...@@ -149,15 +149,14 @@ def set_global_state(global_state: _GlobalState): ...@@ -149,15 +149,14 @@ def set_global_state(global_state: _GlobalState):
async def init_multi_tokenizer() -> ServerArgs: async def init_multi_tokenizer() -> ServerArgs:
"""Read args information from shm and init tokenizer manager for current process""" """Read args information from shm and init tokenizer manager for current process"""
pid = os.getpid()
main_pid = get_main_process_id()
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
# Read configuration from shared memory # Read configuration from shared memory
main_pid = get_main_process_id()
port_args, server_args, scheduler_info = read_from_shared_memory( port_args, server_args, scheduler_info = read_from_shared_memory(
f"multi_tokenizer_args_{main_pid}" f"multi_tokenizer_args_{main_pid}"
) )
server_args: ServerArgs server_args: ServerArgs
port_args: PortArgs
# API key authentication is not supported in multi-tokenizer mode # API key authentication is not supported in multi-tokenizer mode
assert ( assert (
...@@ -167,6 +166,10 @@ async def init_multi_tokenizer() -> ServerArgs: ...@@ -167,6 +166,10 @@ async def init_multi_tokenizer() -> ServerArgs:
port_args.tokenizer_ipc_name = ( port_args.tokenizer_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
) )
logger.info(
f"Start multi-tokenizer worker process {os.getpid()}, "
f"ipc_name={port_args.tokenizer_ipc_name}"
)
# Launch multi-tokenizer manager process # Launch multi-tokenizer manager process
tokenizer_manager = TokenizerWorker(server_args, port_args) tokenizer_manager = TokenizerWorker(server_args, port_args)
...@@ -177,8 +180,6 @@ async def init_multi_tokenizer() -> ServerArgs: ...@@ -177,8 +180,6 @@ async def init_multi_tokenizer() -> ServerArgs:
chat_template=server_args.chat_template, chat_template=server_args.chat_template,
completion_template=server_args.completion_template, completion_template=server_args.completion_template,
) )
# Register this tokenizer with the main tokenizer manager
await tokenizer_manager.register_to_main_tokenizer_manager()
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
set_global_state( set_global_state(
......
...@@ -31,7 +31,6 @@ from sglang.srt.managers.io_struct import ( ...@@ -31,7 +31,6 @@ from sglang.srt.managers.io_struct import (
BatchStrOutput, BatchStrOutput,
BatchTokenIDOutput, BatchTokenIDOutput,
FreezeGCReq, FreezeGCReq,
MultiTokenizerRegisterReq,
) )
from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
...@@ -104,7 +103,6 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): ...@@ -104,7 +103,6 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
(BatchEmbeddingOutput, self.handle_batch_embedding_out), (BatchEmbeddingOutput, self.handle_batch_embedding_out),
(BatchTokenIDOutput, self.handle_batch_token_id_out), (BatchTokenIDOutput, self.handle_batch_token_id_out),
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req), (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
(MultiTokenizerRegisterReq, lambda x: x),
(FreezeGCReq, self.handle_freeze_gc_req), (FreezeGCReq, self.handle_freeze_gc_req),
] ]
) )
...@@ -227,6 +225,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): ...@@ -227,6 +225,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
return BatchStrOutput( return BatchStrOutput(
rids=recv_obj.rids, rids=recv_obj.rids,
http_worker_ipcs=recv_obj.http_worker_ipcs,
finished_reasons=recv_obj.finished_reasons, finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs, output_strs=output_strs,
output_ids=recv_obj.decode_ids, output_ids=recv_obj.decode_ids,
...@@ -258,6 +257,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): ...@@ -258,6 +257,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
outputs = self.tokenizer.detokenize(recv_obj) outputs = self.tokenizer.detokenize(recv_obj)
return BatchMultimodalOutput( return BatchMultimodalOutput(
rids=recv_obj.rids, rids=recv_obj.rids,
http_worker_ipcs=recv_obj.http_worker_ipcs,
finished_reasons=recv_obj.finished_reasons, finished_reasons=recv_obj.finished_reasons,
outputs=outputs, outputs=outputs,
prompt_tokens=recv_obj.prompt_tokens, prompt_tokens=recv_obj.prompt_tokens,
......
...@@ -39,6 +39,7 @@ else: ...@@ -39,6 +39,7 @@ else:
@dataclass @dataclass
class BaseReq(ABC): class BaseReq(ABC):
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True) rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
http_worker_ipc: Optional[str] = field(default=None, kw_only=True)
def regenerate_rid(self): def regenerate_rid(self):
"""Generate a new request ID and return it.""" """Generate a new request ID and return it."""
...@@ -52,6 +53,7 @@ class BaseReq(ABC): ...@@ -52,6 +53,7 @@ class BaseReq(ABC):
@dataclass @dataclass
class BaseBatchReq(ABC): class BaseBatchReq(ABC):
rids: Optional[List[str]] = field(default=None, kw_only=True) rids: Optional[List[str]] = field(default=None, kw_only=True)
http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)
def regenerate_rids(self): def regenerate_rids(self):
"""Generate new request IDs and return them.""" """Generate new request IDs and return them."""
...@@ -1407,18 +1409,6 @@ class LoRAUpdateOutput(BaseReq): ...@@ -1407,18 +1409,6 @@ class LoRAUpdateOutput(BaseReq):
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
@dataclass
class MultiTokenizerRegisterReq(BaseBatchReq):
ipc_name: Optional[str] = None
@dataclass
class MultiTokenizerWrapper:
# FIXME(lsyin): remove this
worker_id: int
obj: Optional[Any] = None
class BlockReqType(Enum): class BlockReqType(Enum):
BLOCK = 1 BLOCK = 1
UNBLOCK = 2 UNBLOCK = 2
......
from __future__ import annotations
# Copyright 2023-2024 SGLang Team # Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,7 +23,7 @@ import sys ...@@ -21,7 +23,7 @@ import sys
import threading import threading
from functools import partialmethod from functools import partialmethod
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Dict from typing import TYPE_CHECKING, Any, Dict, Union
import setproctitle import setproctitle
import zmq import zmq
...@@ -30,12 +32,12 @@ import zmq.asyncio ...@@ -30,12 +32,12 @@ import zmq.asyncio
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
BaseBatchReq,
BaseReq,
BatchEmbeddingOutput, BatchEmbeddingOutput,
BatchMultimodalOutput, BatchMultimodalOutput,
BatchStrOutput, BatchStrOutput,
BatchTokenIDOutput, BatchTokenIDOutput,
MultiTokenizerRegisterReq,
MultiTokenizerWrapper,
) )
from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
...@@ -43,6 +45,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs ...@@ -43,6 +45,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
if TYPE_CHECKING:
from sglang.srt.managers.detokenizer_manager import DetokenizerManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -56,29 +61,24 @@ class SocketMapping: ...@@ -56,29 +61,24 @@ class SocketMapping:
socket.close() socket.close()
self._mapping.clear() self._mapping.clear()
def register_ipc_mapping( def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool):
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
):
type_str = "tokenizer" if is_tokenizer else "detokenizer" type_str = "tokenizer" if is_tokenizer else "detokenizer"
if worker_id in self._mapping: if ipc_name in self._mapping:
logger.warning( logger.warning(f"{type_str} already registered {ipc_name=}, skipping...")
f"{type_str} already registered with worker {worker_id}, skipping..."
)
return return
logger.info( logger.info(f"Registering {type_str} {ipc_name=} in SocketMapping...")
f"{type_str} not registered with worker {worker_id}, registering..." socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
) self._mapping[ipc_name] = socket
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
self._mapping[worker_id] = socket def send_output(self, ipc_name: str, output: Any):
self._mapping[worker_id].send_pyobj(recv_obj) if ipc_name is None:
# Some unhandled cases
def send_output(self, worker_id: str, output: Any): logger.warning(f"IPC name is None, output type={type(output)}, skipping...")
if worker_id not in self._mapping:
logger.error(
f"worker ID {worker_id} not registered. Check if the server Process is alive"
)
return return
self._mapping[worker_id].send_pyobj(output)
if ipc_name not in self._mapping:
self._register_ipc_mapping(ipc_name, is_tokenizer=False)
self._mapping[ipc_name].send_pyobj(output)
def _handle_output_by_index(output, i): def _handle_output_by_index(output, i):
...@@ -362,20 +362,11 @@ def _handle_output_by_index(output, i): ...@@ -362,20 +362,11 @@ def _handle_output_by_index(output, i):
class MultiHttpWorkerDetokenizerMixin: class MultiHttpWorkerDetokenizerMixin:
"""Mixin class for DetokenizerManager""" """Mixin class for DetokenizerManager"""
def get_worker_ids_from_req_rids(self, rids): def maybe_clear_socket_mapping(self: DetokenizerManager):
if isinstance(rids, list):
worker_ids = [int(rid.split("_")[0]) for rid in rids]
elif isinstance(rids, str):
worker_ids = [int(rids.split("_")[0])]
else:
worker_ids = []
return worker_ids
def maybe_clear_socket_mapping(self):
if hasattr(self, "socket_mapping"): if hasattr(self, "socket_mapping"):
self.socket_mapping.clear_all_sockets() self.socket_mapping.clear_all_sockets()
def multi_http_worker_event_loop(self): def multi_http_worker_event_loop(self: DetokenizerManager):
"""The event loop that handles requests, for multi multi-http-worker mode""" """The event loop that handles requests, for multi multi-http-worker mode"""
self.socket_mapping = SocketMapping() self.socket_mapping = SocketMapping()
while True: while True:
...@@ -383,23 +374,15 @@ class MultiHttpWorkerDetokenizerMixin: ...@@ -383,23 +374,15 @@ class MultiHttpWorkerDetokenizerMixin:
output = self._request_dispatcher(recv_obj) output = self._request_dispatcher(recv_obj)
if output is None: if output is None:
continue continue
# Extract worker_id from rid
if isinstance(recv_obj.rids, list): assert isinstance(
worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids) recv_obj, BaseBatchReq
else: ), "for multi-http-worker, recv_obj must be BaseBatchReq"
raise RuntimeError(
f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
)
# Send data using the corresponding socket # Send data using the corresponding socket
for i, worker_id in enumerate(worker_ids): for i, ipc_name in enumerate(recv_obj.http_worker_ipcs):
if isinstance(recv_obj, MultiTokenizerRegisterReq): new_output = _handle_output_by_index(output, i)
self.socket_mapping.register_ipc_mapping( self.socket_mapping.send_output(ipc_name, new_output)
recv_obj, worker_id, is_tokenizer=False
)
else:
new_output = _handle_output_by_index(output, i)
self.socket_mapping.send_output(worker_id, new_output)
class MultiTokenizerRouter: class MultiTokenizerRouter:
...@@ -449,26 +432,17 @@ class MultiTokenizerRouter: ...@@ -449,26 +432,17 @@ class MultiTokenizerRouter:
await self._distribute_result_to_workers(recv_obj) await self._distribute_result_to_workers(recv_obj)
async def _distribute_result_to_workers(self, recv_obj): async def _distribute_result_to_workers(self, recv_obj):
"""Distribute result to corresponding workers based on rid""" # Distribute result to each worker
if isinstance(recv_obj, MultiTokenizerWrapper): if isinstance(recv_obj, BaseReq):
worker_ids = [recv_obj.worker_id] ipc_names = [recv_obj.http_worker_ipc]
recv_obj = recv_obj.obj elif isinstance(recv_obj, BaseBatchReq):
ipc_names = recv_obj.http_worker_ipcs
else: else:
worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids) raise ValueError(f"Unknown recv_obj type: {type(recv_obj)}")
if len(worker_ids) == 0:
logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
return
# Distribute result to each worker for i, ipc_name in enumerate(ipc_names):
for i, worker_id in enumerate(worker_ids): new_recv_obj = _handle_output_by_index(recv_obj, i)
if isinstance(recv_obj, MultiTokenizerRegisterReq): self.socket_mapping.send_output(ipc_name, new_recv_obj)
self.socket_mapping.register_ipc_mapping(
recv_obj, worker_id, is_tokenizer=True
)
else:
new_recv_obj = _handle_output_by_index(recv_obj, i)
self.socket_mapping.send_output(worker_id, new_recv_obj)
class TokenizerWorker(TokenizerManager): class TokenizerWorker(TokenizerManager):
...@@ -500,21 +474,15 @@ class TokenizerWorker(TokenizerManager): ...@@ -500,21 +474,15 @@ class TokenizerWorker(TokenizerManager):
self.register_multi_tokenizer_communicator = _Communicator( self.register_multi_tokenizer_communicator = _Communicator(
self.send_to_scheduler, 2 self.send_to_scheduler, 2
) )
self._result_dispatcher._mapping.append(
(
MultiTokenizerRegisterReq,
self.register_multi_tokenizer_communicator.handle_recv,
)
)
async def register_to_main_tokenizer_manager(self): def _attach_multi_http_worker_info(self, req: Union[BaseReq, BaseBatchReq]):
"""Register this worker to the main TokenizerManager"""
# create a handle loop to receive messages from the main TokenizerManager if isinstance(req, BaseReq):
self.auto_create_handle_loop() req.http_worker_ipc = self.tokenizer_ipc_name
req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"]) elif isinstance(req, BaseBatchReq):
req.ipc_name = self.tokenizer_ipc_name req.http_worker_ipcs = [self.tokenizer_ipc_name] * len(req.rids)
_Communicator.enable_multi_tokenizer = True else:
await self.register_multi_tokenizer_communicator(req) raise ValueError(f"Unknown req type: {type(req)}")
async def print_exception_wrapper(func): async def print_exception_wrapper(func):
......
...@@ -438,6 +438,7 @@ class Req: ...@@ -438,6 +438,7 @@ class Req:
priority: Optional[int] = None, priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None, metrics_collector: Optional[SchedulerMetricsCollector] = None,
extra_key: Optional[str] = None, extra_key: Optional[str] = None,
http_worker_ipc: Optional[str] = None,
): ):
# Input and output info # Input and output info
self.rid = rid self.rid = rid
...@@ -461,6 +462,9 @@ class Req: ...@@ -461,6 +462,9 @@ class Req:
# The length of KV that have been removed in local attention chunked prefill # The length of KV that have been removed in local attention chunked prefill
self.evicted_seqlen_local = 0 self.evicted_seqlen_local = 0
# For multi-http worker
self.http_worker_ipc = http_worker_ipc
# Sampling info # Sampling info
if isinstance(sampling_params.custom_params, dict): if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params) sampling_params = copy.copy(sampling_params)
......
...@@ -24,7 +24,6 @@ from collections import deque ...@@ -24,7 +24,6 @@ from collections import deque
from concurrent import futures from concurrent import futures
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from types import SimpleNamespace
from typing import Deque, Dict, List, Optional, Tuple, Union from typing import Deque, Dict, List, Optional, Tuple, Union
import psutil import psutil
...@@ -66,6 +65,8 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info ...@@ -66,6 +65,8 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.moe import initialize_moe_config from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BaseBatchReq,
BaseReq,
BatchTokenizedEmbeddingReqInput, BatchTokenizedEmbeddingReqInput,
BatchTokenizedGenerateReqInput, BatchTokenizedGenerateReqInput,
ClearHiCacheReqInput, ClearHiCacheReqInput,
...@@ -89,8 +90,6 @@ from sglang.srt.managers.io_struct import ( ...@@ -89,8 +90,6 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput, LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput, LoadLoRAAdapterReqOutput,
MultiTokenizerRegisterReq,
MultiTokenizerWrapper,
OpenSessionReqInput, OpenSessionReqInput,
OpenSessionReqOutput, OpenSessionReqOutput,
ProfileReq, ProfileReq,
...@@ -277,47 +276,7 @@ class Scheduler( ...@@ -277,47 +276,7 @@ class Scheduler(
self.model_config = ModelConfig.from_server_args(server_args) self.model_config = ModelConfig.from_server_args(server_args)
# Init inter-process communication # Init inter-process communication
context = zmq.Context(2) self.init_sockets(server_args, port_args)
self.idle_sleeper = None
if self.pp_rank == 0 and self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
else:
# Send to the DetokenizerManager
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
if self.server_args.sleep_on_idle:
self.idle_sleeper = IdleSleeper(
[
self.recv_from_tokenizer,
self.recv_from_rpc,
]
)
else:
self.recv_from_tokenizer = None
self.recv_from_rpc = None
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
if self.current_scheduler_metrics_enabled():
self.send_metrics_from_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.metrics_ipc_name, False
)
# Init tokenizer # Init tokenizer
self.init_tokenizer() self.init_tokenizer()
...@@ -578,7 +537,6 @@ class Scheduler( ...@@ -578,7 +537,6 @@ class Scheduler(
(ExpertDistributionReq, self.expert_distribution_handle), (ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter), (LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter), (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
(GetLoadReqInput, self.get_load), (GetLoadReqInput, self.get_load),
] ]
) )
...@@ -634,6 +592,75 @@ class Scheduler( ...@@ -634,6 +592,75 @@ class Scheduler(
else: else:
self.draft_worker = None self.draft_worker = None
def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
context = zmq.Context(2)
self.idle_sleeper = None
class SenderWrapper:
def __init__(self, socket: zmq.Socket):
self.socket = socket
def send_output(
self,
output: Union[BaseReq, BaseBatchReq],
recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None,
):
if self.socket is None:
return
if (
isinstance(recv_obj, BaseReq)
and recv_obj.http_worker_ipc is not None
and output.http_worker_ipc is None
):
# handle communicator reqs for multi-http worker case
output.http_worker_ipc = recv_obj.http_worker_ipc
self.socket.send_pyobj(output)
if self.pp_rank == 0 and self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager
send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
else:
# Send to the DetokenizerManager
send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
self.send_to_tokenizer = SenderWrapper(send_to_tokenizer)
self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)
if self.server_args.sleep_on_idle:
self.idle_sleeper = IdleSleeper(
[
self.recv_from_tokenizer,
self.recv_from_rpc,
]
)
else:
self.recv_from_tokenizer = None
self.recv_from_rpc = None
self.send_to_tokenizer = SenderWrapper(None)
self.send_to_detokenizer = SenderWrapper(None)
if self.current_scheduler_metrics_enabled():
self.send_metrics_from_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.metrics_ipc_name, False
)
def init_deterministic_inference_config(self): def init_deterministic_inference_config(self):
"""Initialize deterministic inference configuration for different attention backends.""" """Initialize deterministic inference configuration for different attention backends."""
if not self.server_args.enable_deterministic_inference: if not self.server_args.enable_deterministic_inference:
...@@ -1107,23 +1134,13 @@ class Scheduler( ...@@ -1107,23 +1134,13 @@ class Scheduler(
self.return_health_check_ct += 1 self.return_health_check_ct += 1
continue continue
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
if isinstance(recv_req, MultiTokenizerWrapper):
worker_id = recv_req.worker_id
recv_req = recv_req.obj
output = self._request_dispatcher(recv_req)
if output is not None:
output = MultiTokenizerWrapper(worker_id, output)
self.send_to_tokenizer.send_pyobj(output)
continue
output = self._request_dispatcher(recv_req) output = self._request_dispatcher(recv_req)
if output is not None: if output is not None:
if isinstance(output, RpcReqOutput): if isinstance(output, RpcReqOutput):
if self.recv_from_rpc is not None: if self.recv_from_rpc is not None:
self.recv_from_rpc.send_pyobj(output) self.recv_from_rpc.send_pyobj(output)
else: else:
self.send_to_tokenizer.send_pyobj(output) self.send_to_tokenizer.send_output(output, recv_req)
def init_req_max_new_tokens(self, req): def init_req_max_new_tokens(self, req):
req.sampling_params.max_new_tokens = min( req.sampling_params.max_new_tokens = min(
...@@ -1179,6 +1196,7 @@ class Scheduler( ...@@ -1179,6 +1196,7 @@ class Scheduler(
metrics_collector=( metrics_collector=(
self.metrics_collector if self.enable_metrics else None self.metrics_collector if self.enable_metrics else None
), ),
http_worker_ipc=recv_req.http_worker_ipc,
) )
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer
...@@ -1382,7 +1400,7 @@ class Scheduler( ...@@ -1382,7 +1400,7 @@ class Scheduler(
}, },
rid=req.rid, rid=req.rid,
) )
self.send_to_tokenizer.send_pyobj(abort_req) self.send_to_tokenizer.send_output(abort_req, req)
def _abort_on_queued_limit(self, recv_req: Req) -> bool: def _abort_on_queued_limit(self, recv_req: Req) -> bool:
"""Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted.""" """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
...@@ -1414,7 +1432,7 @@ class Scheduler( ...@@ -1414,7 +1432,7 @@ class Scheduler(
req_to_abort = candidate_req req_to_abort = candidate_req
message = "The request is aborted by a higher priority request." message = "The request is aborted by a higher priority request."
self.send_to_tokenizer.send_pyobj( self.send_to_tokenizer.send_output(
AbortReq( AbortReq(
finished_reason={ finished_reason={
"type": "abort", "type": "abort",
...@@ -1422,7 +1440,8 @@ class Scheduler( ...@@ -1422,7 +1440,8 @@ class Scheduler(
"message": message, "message": message,
}, },
rid=req_to_abort.rid, rid=req_to_abort.rid,
) ),
req_to_abort,
) )
return req_to_abort.rid == recv_req.rid return req_to_abort.rid == recv_req.rid
...@@ -1437,6 +1456,7 @@ class Scheduler( ...@@ -1437,6 +1456,7 @@ class Scheduler(
recv_req.sampling_params, recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids, token_type_ids=recv_req.token_type_ids,
priority=recv_req.priority, priority=recv_req.priority,
http_worker_ipc=recv_req.http_worker_ipc,
) )
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer
...@@ -1953,8 +1973,8 @@ class Scheduler( ...@@ -1953,8 +1973,8 @@ class Scheduler(
self.num_retracted_reqs = len(retracted_reqs) self.num_retracted_reqs = len(retracted_reqs)
self.new_token_ratio = new_token_ratio self.new_token_ratio = new_token_ratio
for req in reqs_to_abort: for req in reqs_to_abort:
self.send_to_tokenizer.send_pyobj( self.send_to_tokenizer.send_output(
AbortReq(abort_reason=req.to_abort_message, rid=req.rid) AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
) )
logger.info( logger.info(
...@@ -2138,7 +2158,7 @@ class Scheduler( ...@@ -2138,7 +2158,7 @@ class Scheduler(
# This is used to prevent the health check signal being blocked by long context prefill. # This is used to prevent the health check signal being blocked by long context prefill.
# However, one minor issue is that this code path does not check the status of detokenizer manager. # However, one minor issue is that this code path does not check the status of detokenizer manager.
self.return_health_check_ct -= 1 self.return_health_check_ct -= 1
self.send_to_tokenizer.send_pyobj(HealthCheckOutput()) self.send_to_tokenizer.send_output(HealthCheckOutput())
def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch): def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
return self.prepare_mlp_sync_batch_raw( return self.prepare_mlp_sync_batch_raw(
...@@ -2585,7 +2605,7 @@ class Scheduler( ...@@ -2585,7 +2605,7 @@ class Scheduler(
if self.enable_hicache_storage: if self.enable_hicache_storage:
# to release prefetch events associated with the request # to release prefetch events associated with the request
self.tree_cache.release_aborted_request(req.rid) self.tree_cache.release_aborted_request(req.rid)
self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid)) self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated. # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
if self.disaggregation_mode == DisaggregationMode.DECODE: if self.disaggregation_mode == DisaggregationMode.DECODE:
self.tree_cache.cache_finished_req(req) self.tree_cache.cache_finished_req(req)
...@@ -2669,10 +2689,6 @@ class Scheduler( ...@@ -2669,10 +2689,6 @@ class Scheduler(
result = self.tp_worker.unload_lora_adapter(recv_req) result = self.tp_worker.unload_lora_adapter(recv_req)
return result return result
def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
self.send_to_detokenizer.send_pyobj(recv_req)
return recv_req
def init_weights_send_group_for_remote_instance( def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
): ):
...@@ -2751,7 +2767,7 @@ class Scheduler( ...@@ -2751,7 +2767,7 @@ class Scheduler(
def handle_freeze_gc(self, recv_req: FreezeGCReq): def handle_freeze_gc(self, recv_req: FreezeGCReq):
"""Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer.""" """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
freeze_gc("Scheduler") freeze_gc("Scheduler")
self.send_to_detokenizer.send_pyobj(recv_req) self.send_to_detokenizer.send_output(recv_req, recv_req)
return None return None
......
...@@ -682,6 +682,7 @@ class SchedulerOutputProcessorMixin: ...@@ -682,6 +682,7 @@ class SchedulerOutputProcessorMixin:
skip_req: Optional[Req] = None, skip_req: Optional[Req] = None,
): ):
rids = [] rids = []
http_worker_ipcs = []
finished_reasons: List[BaseFinishReason] = [] finished_reasons: List[BaseFinishReason] = []
decoded_texts = [] decoded_texts = []
...@@ -770,6 +771,7 @@ class SchedulerOutputProcessorMixin: ...@@ -770,6 +771,7 @@ class SchedulerOutputProcessorMixin:
req.send_output_token_logprobs_offset req.send_output_token_logprobs_offset
) )
rids.append(req.rid) rids.append(req.rid)
http_worker_ipcs.append(req.http_worker_ipc)
finished_reasons.append( finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None req.finished_reason.to_json() if req.finished_reason else None
) )
...@@ -886,7 +888,7 @@ class SchedulerOutputProcessorMixin: ...@@ -886,7 +888,7 @@ class SchedulerOutputProcessorMixin:
if self.model_config.is_multimodal_gen: if self.model_config.is_multimodal_gen:
return return
self.send_to_detokenizer.send_pyobj( self.send_to_detokenizer.send_output(
BatchTokenIDOutput( BatchTokenIDOutput(
finished_reasons, finished_reasons,
decoded_texts, decoded_texts,
...@@ -916,6 +918,7 @@ class SchedulerOutputProcessorMixin: ...@@ -916,6 +918,7 @@ class SchedulerOutputProcessorMixin:
output_token_entropy_val=None, output_token_entropy_val=None,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
rids=rids, rids=rids,
http_worker_ipcs=http_worker_ipcs,
placeholder_tokens_idx=None, placeholder_tokens_idx=None,
placeholder_tokens_val=None, placeholder_tokens_val=None,
) )
...@@ -923,6 +926,7 @@ class SchedulerOutputProcessorMixin: ...@@ -923,6 +926,7 @@ class SchedulerOutputProcessorMixin:
def stream_output_embedding(self: Scheduler, reqs: List[Req]): def stream_output_embedding(self: Scheduler, reqs: List[Req]):
rids = [] rids = []
http_worker_ipcs = []
finished_reasons: List[BaseFinishReason] = [] finished_reasons: List[BaseFinishReason] = []
embeddings = [] embeddings = []
...@@ -931,17 +935,19 @@ class SchedulerOutputProcessorMixin: ...@@ -931,17 +935,19 @@ class SchedulerOutputProcessorMixin:
for req in reqs: for req in reqs:
if req.finished(): if req.finished():
rids.append(req.rid) rids.append(req.rid)
http_worker_ipcs.append(req.http_worker_ipc)
finished_reasons.append(req.finished_reason.to_json()) finished_reasons.append(req.finished_reason.to_json())
embeddings.append(req.embedding) embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids)) prompt_tokens.append(len(req.origin_input_ids))
cached_tokens.append(req.cached_tokens) cached_tokens.append(req.cached_tokens)
self.send_to_detokenizer.send_pyobj( self.send_to_detokenizer.send_output(
BatchEmbeddingOutput( BatchEmbeddingOutput(
finished_reasons, finished_reasons,
embeddings, embeddings,
prompt_tokens, prompt_tokens,
cached_tokens, cached_tokens,
rids=rids, rids=rids,
http_worker_ipcs=http_worker_ipcs,
placeholder_tokens_idx=None, placeholder_tokens_idx=None,
placeholder_tokens_val=None, placeholder_tokens_val=None,
) )
......
...@@ -3,7 +3,6 @@ from __future__ import annotations ...@@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio import asyncio
import copy import copy
import logging import logging
import os
import time import time
import uuid import uuid
from collections import deque from collections import deque
...@@ -46,7 +45,6 @@ from sglang.srt.managers.io_struct import ( ...@@ -46,7 +45,6 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqInput, LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput, LoadLoRAAdapterReqOutput,
LoRAUpdateOutput, LoRAUpdateOutput,
MultiTokenizerWrapper,
OpenSessionReqInput, OpenSessionReqInput,
ProfileReq, ProfileReq,
ProfileReqOutput, ProfileReqOutput,
...@@ -83,8 +81,6 @@ logger = logging.getLogger(__name__) ...@@ -83,8 +81,6 @@ logger = logging.getLogger(__name__)
class _Communicator(Generic[T]): class _Communicator(Generic[T]):
"""Note: The communicator now only run up to 1 in-flight request at any time.""" """Note: The communicator now only run up to 1 in-flight request at any time."""
enable_multi_tokenizer = False
def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"): def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
self._sender = sender self._sender = sender
self._fan_out = fan_out self._fan_out = fan_out
...@@ -104,8 +100,6 @@ class _Communicator(Generic[T]): ...@@ -104,8 +100,6 @@ class _Communicator(Generic[T]):
assert self._result_values is None assert self._result_values is None
if obj: if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj) self._sender.send_pyobj(obj)
self._result_event = asyncio.Event() self._result_event = asyncio.Event()
...@@ -126,8 +120,6 @@ class _Communicator(Generic[T]): ...@@ -126,8 +120,6 @@ class _Communicator(Generic[T]):
self._result_event = asyncio.Event() self._result_event = asyncio.Event()
if obj: if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj) self._sender.send_pyobj(obj)
await self._result_event.wait() await self._result_event.wait()
...@@ -617,8 +609,6 @@ class TokenizerCommunicatorMixin: ...@@ -617,8 +609,6 @@ class TokenizerCommunicatorMixin:
elif obj.session_id in self.session_futures: elif obj.session_id in self.session_futures:
return None return None
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj) self.send_to_scheduler.send_pyobj(obj)
self.session_futures[obj.session_id] = asyncio.Future() self.session_futures[obj.session_id] = asyncio.Future()
......
...@@ -46,6 +46,7 @@ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchT ...@@ -46,6 +46,7 @@ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchT
from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BaseReq,
BatchEmbeddingOutput, BatchEmbeddingOutput,
BatchMultimodalOutput, BatchMultimodalOutput,
BatchStrOutput, BatchStrOutput,
...@@ -58,7 +59,6 @@ from sglang.srt.managers.io_struct import ( ...@@ -58,7 +59,6 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput, GenerateReqInput,
GetLoadReqInput, GetLoadReqInput,
HealthCheckOutput, HealthCheckOutput,
MultiTokenizerWrapper,
OpenSessionReqOutput, OpenSessionReqOutput,
SessionParams, SessionParams,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
...@@ -88,7 +88,6 @@ from sglang.srt.utils import ( ...@@ -88,7 +88,6 @@ from sglang.srt.utils import (
dataclass_to_string_truncated, dataclass_to_string_truncated,
freeze_gc, freeze_gc,
get_bool_env_var, get_bool_env_var,
get_origin_rid,
get_zmq_socket, get_zmq_socket,
kill_process_tree, kill_process_tree,
) )
...@@ -258,9 +257,18 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -258,9 +257,18 @@ class TokenizerManager(TokenizerCommunicatorMixin):
) )
if self.server_args.tokenizer_worker_num > 1: if self.server_args.tokenizer_worker_num > 1:
# Use tokenizer_worker_ipc_name in multi-tokenizer mode # Use tokenizer_worker_ipc_name in multi-tokenizer mode
self.send_to_scheduler = get_zmq_socket( send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
) )
class SenderWrapper:
def send_pyobj(self, obj):
if isinstance(obj, BaseReq):
obj.http_worker_ipc = port_args.tokenizer_ipc_name
send_to_scheduler.send_pyobj(obj)
# Make sure that each request carries the tokenizer_ipc_name for response routing
self.send_to_scheduler = SenderWrapper()
else: else:
self.send_to_scheduler = get_zmq_socket( self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
...@@ -376,13 +384,10 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -376,13 +384,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
obj.normalize_batch_and_arguments() obj.normalize_batch_and_arguments()
if self.server_args.tokenizer_worker_num > 1: if self.server_args.tokenizer_worker_num > 1:
# Modify rid, add worker_id from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker
if isinstance(obj.rid, list):
# If it's an array, add worker_id prefix to each element assert isinstance(self, TokenizerWorker)
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid] self._attach_multi_http_worker_info(obj)
else:
# If it's a single value, add worker_id prefix
obj.rid = f"{self.worker_id}_{obj.rid}"
if self.enable_trace: if self.enable_trace:
self._trace_request_start(obj, created_time) self._trace_request_start(obj, created_time)
...@@ -728,6 +733,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -728,6 +733,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
obj.token_ids_logprob, obj.token_ids_logprob,
obj.stream, obj.stream,
rid=obj.rid, rid=obj.rid,
http_worker_ipc=obj.http_worker_ipc,
bootstrap_host=obj.bootstrap_host, bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port, bootstrap_port=obj.bootstrap_port,
bootstrap_room=obj.bootstrap_room, bootstrap_room=obj.bootstrap_room,
...@@ -749,6 +755,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -749,6 +755,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
sampling_params, sampling_params,
rid=obj.rid, rid=obj.rid,
priority=obj.priority, priority=obj.priority,
http_worker_ipc=obj.http_worker_ipc,
) )
return tokenized_obj return tokenized_obj
...@@ -1109,8 +1116,6 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -1109,8 +1116,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
async def _wait_for_model_update_from_disk( async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput self, obj: UpdateWeightFromDiskReqInput
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj) self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future() self.model_update_result = asyncio.Future()
if self.server_args.dp_size == 1: if self.server_args.dp_size == 1:
...@@ -1349,12 +1354,9 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -1349,12 +1354,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
) )
continue continue
origin_rid = rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(rid)
# Build meta_info and return value # Build meta_info and return value
meta_info = { meta_info = {
"id": origin_rid, "id": rid,
"finish_reason": recv_obj.finished_reasons[i], "finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i], "prompt_tokens": recv_obj.prompt_tokens[i],
"weight_version": self.server_args.weight_version, "weight_version": self.server_args.weight_version,
...@@ -1708,9 +1710,6 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -1708,9 +1710,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if is_health_check_generate_req(recv_obj): if is_health_check_generate_req(recv_obj):
return return
state = self.rid_to_state[recv_obj.rid] state = self.rid_to_state[recv_obj.rid]
origin_rid = recv_obj.rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(origin_rid)
state.finished = True state.finished = True
if recv_obj.finished_reason: if recv_obj.finished_reason:
out = { out = {
...@@ -1723,7 +1722,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -1723,7 +1722,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
out = { out = {
"text": "", "text": "",
"meta_info": { "meta_info": {
"id": origin_rid, "id": recv_obj.rid,
"finish_reason": { "finish_reason": {
"type": "abort", "type": "abort",
"message": "Abort before prefill", "message": "Abort before prefill",
......
...@@ -3006,10 +3006,6 @@ def lru_cache_frozenset(maxsize=128): ...@@ -3006,10 +3006,6 @@ def lru_cache_frozenset(maxsize=128):
return decorator return decorator
def get_origin_rid(rid):
return rid.split("_", 1)[1] if "_" in rid else rid
def apply_module_patch(target_module, target_function, wrappers): def apply_module_patch(target_module, target_function, wrappers):
original_module, original_function = parse_module_path( original_module, original_function = parse_module_path(
target_module, target_function, False target_module, target_function, False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment