Unverified Commit 260fe755 authored by Zhengke Zhou, committed by GitHub

Simplify multi-tokenizer (#11295)
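
Summary of the change: the multi-tokenizer registration handshake (MultiTokenizerRegisterReq / MultiTokenizerWrapper) and the worker-id prefix spliced into every rid are removed. Instead, every request and batch now carries an explicit reply address (http_worker_ipc / http_worker_ipcs) stamped by the HTTP worker that issued it, and the detokenizer and scheduler route results back by lazily opening a ZMQ PUSH socket per IPC name. An illustrative sketch (simplified names, Python 3.10+ for kw_only fields):

```python
# Before: the worker id was spliced into every request id on the way in and
# parsed back out downstream.
rid = f"{12345}_abcd"                # tokenizer worker prefixes its pid
worker_id = int(rid.split("_")[0])   # detokenizer recovers the worker id

# After: the request carries its reply IPC address explicitly.
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class BaseReq:
    rid: Optional[str] = field(default=None, kw_only=True)
    http_worker_ipc: Optional[str] = field(default=None, kw_only=True)

req = BaseReq(rid="abcd", http_worker_ipc="ipc:///tmp/http_worker_0")
```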


Signed-off-by: zhengkezhou1 <madzhou1@gmail.com>
Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
parent dbb16bed
......@@ -149,15 +149,14 @@ def set_global_state(global_state: _GlobalState):
async def init_multi_tokenizer() -> ServerArgs:
"""Read args information from shm and init tokenizer manager for current process"""
pid = os.getpid()
main_pid = get_main_process_id()
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
# Read configuration from shared memory
main_pid = get_main_process_id()
port_args, server_args, scheduler_info = read_from_shared_memory(
f"multi_tokenizer_args_{main_pid}"
)
server_args: ServerArgs
port_args: PortArgs
# API key authentication is not supported in multi-tokenizer mode
assert (
......@@ -167,6 +166,10 @@ async def init_multi_tokenizer() -> ServerArgs:
port_args.tokenizer_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
)
logger.info(
f"Start multi-tokenizer worker process {os.getpid()}, "
f"ipc_name={port_args.tokenizer_ipc_name}"
)
# Launch multi-tokenizer manager process
tokenizer_manager = TokenizerWorker(server_args, port_args)
......@@ -177,8 +180,6 @@ async def init_multi_tokenizer() -> ServerArgs:
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
)
# Register this tokenizer with the main tokenizer manager
await tokenizer_manager.register_to_main_tokenizer_manager()
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
set_global_state(
......
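For context, init_multi_tokenizer's shared-memory handoff follows this pattern; a minimal sketch only, assuming the real read_from_shared_memory in sglang.srt.utils unpickles a named multiprocessing block (details may differ):

```python
import pickle
from multiprocessing import shared_memory

def write_to_shared_memory(name: str, obj) -> shared_memory.SharedMemory:
    """Sketch: the main process pickles its args into a named block."""
    data = pickle.dumps(obj)
    shm = shared_memory.SharedMemory(name=name, create=True, size=len(data))
    shm.buf[: len(data)] = data
    return shm  # keep a reference alive so the block is not reclaimed

def read_from_shared_memory(name: str):
    """Sketch: each tokenizer worker reads the same block back by name."""
    shm = shared_memory.SharedMemory(name=name)
    try:
        return pickle.loads(bytes(shm.buf))  # pickle ignores trailing padding
    finally:
        shm.close()
```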
......@@ -31,7 +31,6 @@ from sglang.srt.managers.io_struct import (
BatchStrOutput,
BatchTokenIDOutput,
FreezeGCReq,
MultiTokenizerRegisterReq,
)
from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
from sglang.srt.server_args import PortArgs, ServerArgs
......@@ -104,7 +103,6 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
(BatchEmbeddingOutput, self.handle_batch_embedding_out),
(BatchTokenIDOutput, self.handle_batch_token_id_out),
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
(MultiTokenizerRegisterReq, lambda x: x),
(FreezeGCReq, self.handle_freeze_gc_req),
]
)
......@@ -227,6 +225,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
return BatchStrOutput(
rids=recv_obj.rids,
http_worker_ipcs=recv_obj.http_worker_ipcs,
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
output_ids=recv_obj.decode_ids,
......@@ -258,6 +257,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
outputs = self.tokenizer.detokenize(recv_obj)
return BatchMultimodalOutput(
rids=recv_obj.rids,
http_worker_ipcs=recv_obj.http_worker_ipcs,
finished_reasons=recv_obj.finished_reasons,
outputs=outputs,
prompt_tokens=recv_obj.prompt_tokens,
......
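Both batch outputs now carry http_worker_ipcs parallel to rids so the mixin's event loop can fan entries out by position. A rough sketch of the per-index split that _handle_output_by_index performs (the real helper is more careful about non-list and None fields):

```python
import dataclasses

def handle_output_by_index(output, i: int):
    """Sketch: rebuild a batch dataclass keeping only entry i of each
    list-valued field, so it can be pushed to that request's HTTP worker."""
    changes = {
        f.name: [getattr(output, f.name)[i]]
        for f in dataclasses.fields(output)
        if isinstance(getattr(output, f.name), list)
    }
    return dataclasses.replace(output, **changes)
```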
......@@ -39,6 +39,7 @@ else:
@dataclass
class BaseReq(ABC):
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
http_worker_ipc: Optional[str] = field(default=None, kw_only=True)
def regenerate_rid(self):
"""Generate a new request ID and return it."""
......@@ -52,6 +53,7 @@ class BaseReq(ABC):
@dataclass
class BaseBatchReq(ABC):
rids: Optional[List[str]] = field(default=None, kw_only=True)
http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)
def regenerate_rids(self):
"""Generate new request IDs and return them."""
......@@ -1407,18 +1409,6 @@ class LoRAUpdateOutput(BaseReq):
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
@dataclass
class MultiTokenizerRegisterReq(BaseBatchReq):
ipc_name: Optional[str] = None
@dataclass
class MultiTokenizerWrapper:
# FIXME(lsyin): remove this
worker_id: int
obj: Optional[Any] = None
class BlockReqType(Enum):
BLOCK = 1
UNBLOCK = 2
......
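The new routing fields are declared kw_only=True (Python 3.10+), so they land at the end of every subclass __init__ and do not disturb existing positional fields; a self-contained sketch:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class BaseBatchReq:
    rids: Optional[List[str]] = field(default=None, kw_only=True)
    http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)

@dataclass
class BatchStrOutput(BaseBatchReq):
    output_strs: List[str] = field(default_factory=list)

# Subclass fields remain positional; the inherited routing fields are
# keyword-only and default to None when there is no HTTP worker.
out = BatchStrOutput(["hi"], rids=["r0"], http_worker_ipcs=["ipc:///tmp/w0"])
```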
from __future__ import annotations
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,7 +23,7 @@ import sys
import threading
from functools import partialmethod
from multiprocessing import shared_memory
from typing import Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Union
import setproctitle
import zmq
......@@ -30,12 +32,12 @@ import zmq.asyncio
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
BaseBatchReq,
BaseReq,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchStrOutput,
BatchTokenIDOutput,
MultiTokenizerRegisterReq,
MultiTokenizerWrapper,
)
from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
from sglang.srt.managers.tokenizer_manager import TokenizerManager
......@@ -43,6 +45,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback
if TYPE_CHECKING:
from sglang.srt.managers.detokenizer_manager import DetokenizerManager
logger = logging.getLogger(__name__)
......@@ -56,29 +61,24 @@ class SocketMapping:
socket.close()
self._mapping.clear()
def register_ipc_mapping(
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
):
def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool):
type_str = "tokenizer" if is_tokenizer else "detokenizer"
if worker_id in self._mapping:
logger.warning(
f"{type_str} already registered with worker {worker_id}, skipping..."
)
if ipc_name in self._mapping:
logger.warning(f"{type_str} already registered {ipc_name=}, skipping...")
return
logger.info(
f"{type_str} not registered with worker {worker_id}, registering..."
)
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
self._mapping[worker_id] = socket
self._mapping[worker_id].send_pyobj(recv_obj)
def send_output(self, worker_id: str, output: Any):
if worker_id not in self._mapping:
logger.error(
f"worker ID {worker_id} not registered. Check if the server Process is alive"
)
logger.info(f"Registering {type_str} {ipc_name=} in SocketMapping...")
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
self._mapping[ipc_name] = socket
def send_output(self, ipc_name: str, output: Any):
if ipc_name is None:
# Some unhandled cases
logger.warning(f"IPC name is None, output type={type(output)}, skipping...")
return
self._mapping[worker_id].send_pyobj(output)
if ipc_name not in self._mapping:
self._register_ipc_mapping(ipc_name, is_tokenizer=False)
self._mapping[ipc_name].send_pyobj(output)
def _handle_output_by_index(output, i):
......@@ -362,20 +362,11 @@ def _handle_output_by_index(output, i):
class MultiHttpWorkerDetokenizerMixin:
"""Mixin class for DetokenizerManager"""
def get_worker_ids_from_req_rids(self, rids):
if isinstance(rids, list):
worker_ids = [int(rid.split("_")[0]) for rid in rids]
elif isinstance(rids, str):
worker_ids = [int(rids.split("_")[0])]
else:
worker_ids = []
return worker_ids
def maybe_clear_socket_mapping(self):
def maybe_clear_socket_mapping(self: DetokenizerManager):
if hasattr(self, "socket_mapping"):
self.socket_mapping.clear_all_sockets()
def multi_http_worker_event_loop(self):
def multi_http_worker_event_loop(self: DetokenizerManager):
"""The event loop that handles requests, for multi multi-http-worker mode"""
self.socket_mapping = SocketMapping()
while True:
......@@ -383,23 +374,15 @@ class MultiHttpWorkerDetokenizerMixin:
output = self._request_dispatcher(recv_obj)
if output is None:
continue
# Extract worker_id from rid
if isinstance(recv_obj.rids, list):
worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
else:
raise RuntimeError(
f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
)
assert isinstance(
recv_obj, BaseBatchReq
), "for multi-http-worker, recv_obj must be BaseBatchReq"
# Send data using the corresponding socket
for i, worker_id in enumerate(worker_ids):
if isinstance(recv_obj, MultiTokenizerRegisterReq):
self.socket_mapping.register_ipc_mapping(
recv_obj, worker_id, is_tokenizer=False
)
else:
new_output = _handle_output_by_index(output, i)
self.socket_mapping.send_output(worker_id, new_output)
for i, ipc_name in enumerate(recv_obj.http_worker_ipcs):
new_output = _handle_output_by_index(output, i)
self.socket_mapping.send_output(ipc_name, new_output)
class MultiTokenizerRouter:
......@@ -449,26 +432,17 @@ class MultiTokenizerRouter:
await self._distribute_result_to_workers(recv_obj)
async def _distribute_result_to_workers(self, recv_obj):
"""Distribute result to corresponding workers based on rid"""
if isinstance(recv_obj, MultiTokenizerWrapper):
worker_ids = [recv_obj.worker_id]
recv_obj = recv_obj.obj
# Distribute result to each worker
if isinstance(recv_obj, BaseReq):
ipc_names = [recv_obj.http_worker_ipc]
elif isinstance(recv_obj, BaseBatchReq):
ipc_names = recv_obj.http_worker_ipcs
else:
worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
if len(worker_ids) == 0:
logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
return
raise ValueError(f"Unknown recv_obj type: {type(recv_obj)}")
# Distribute result to each worker
for i, worker_id in enumerate(worker_ids):
if isinstance(recv_obj, MultiTokenizerRegisterReq):
self.socket_mapping.register_ipc_mapping(
recv_obj, worker_id, is_tokenizer=True
)
else:
new_recv_obj = _handle_output_by_index(recv_obj, i)
self.socket_mapping.send_output(worker_id, new_recv_obj)
for i, ipc_name in enumerate(ipc_names):
new_recv_obj = _handle_output_by_index(recv_obj, i)
self.socket_mapping.send_output(ipc_name, new_recv_obj)
class TokenizerWorker(TokenizerManager):
......@@ -500,21 +474,15 @@ class TokenizerWorker(TokenizerManager):
self.register_multi_tokenizer_communicator = _Communicator(
self.send_to_scheduler, 2
)
self._result_dispatcher._mapping.append(
(
MultiTokenizerRegisterReq,
self.register_multi_tokenizer_communicator.handle_recv,
)
)
async def register_to_main_tokenizer_manager(self):
"""Register this worker to the main TokenizerManager"""
# create a handle loop to receive messages from the main TokenizerManager
self.auto_create_handle_loop()
req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
req.ipc_name = self.tokenizer_ipc_name
_Communicator.enable_multi_tokenizer = True
await self.register_multi_tokenizer_communicator(req)
def _attach_multi_http_worker_info(self, req: Union[BaseReq, BaseBatchReq]):
if isinstance(req, BaseReq):
req.http_worker_ipc = self.tokenizer_ipc_name
elif isinstance(req, BaseBatchReq):
req.http_worker_ipcs = [self.tokenizer_ipc_name] * len(req.rids)
else:
raise ValueError(f"Unknown req type: {type(req)}")
async def print_exception_wrapper(func):
......
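The registration request is gone entirely: SocketMapping now opens a PUSH socket lazily the first time an output targets a given IPC name. A condensed, runnable sketch of that registry using plain pyzmq (get_zmq_socket in the real code additionally tunes socket options):

```python
import logging
from typing import Any, Dict

import zmq

logger = logging.getLogger(__name__)

class SocketMapping:
    """Sketch of the lazy ipc_name -> PUSH-socket registry from this diff."""

    def __init__(self):
        self._zmq_context = zmq.Context()
        self._mapping: Dict[str, zmq.Socket] = {}

    def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool):
        type_str = "tokenizer" if is_tokenizer else "detokenizer"
        if ipc_name in self._mapping:
            logger.warning(f"{type_str} already registered {ipc_name=}, skipping...")
            return
        socket = self._zmq_context.socket(zmq.PUSH)
        socket.connect(ipc_name)
        self._mapping[ipc_name] = socket

    def send_output(self, ipc_name: str, output: Any):
        if ipc_name is None:
            return  # request did not originate from an HTTP worker
        if ipc_name not in self._mapping:
            self._register_ipc_mapping(ipc_name, is_tokenizer=False)
        self._mapping[ipc_name].send_pyobj(output)
```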
......@@ -438,6 +438,7 @@ class Req:
priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
extra_key: Optional[str] = None,
http_worker_ipc: Optional[str] = None,
):
# Input and output info
self.rid = rid
......@@ -461,6 +462,9 @@ class Req:
# The length of KV that have been removed in local attention chunked prefill
self.evicted_seqlen_local = 0
# For multi-http worker
self.http_worker_ipc = http_worker_ipc
# Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
......
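Inside the scheduler, the address is just one more attribute carried on Req from admission to completion; a trimmed sketch keeping only the routing-relevant fields:

```python
from typing import Optional

class Req:
    """Sketch: a scheduler request remembering its reply address."""

    def __init__(self, rid: str, http_worker_ipc: Optional[str] = None):
        self.rid = rid
        # For multi-http worker: which HTTP worker's IPC socket gets the result
        self.http_worker_ipc = http_worker_ipc
```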
......@@ -24,7 +24,6 @@ from collections import deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace
from typing import Deque, Dict, List, Optional, Tuple, Union
import psutil
......@@ -66,6 +65,8 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.io_struct import (
AbortReq,
BaseBatchReq,
BaseReq,
BatchTokenizedEmbeddingReqInput,
BatchTokenizedGenerateReqInput,
ClearHiCacheReqInput,
......@@ -89,8 +90,6 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
MultiTokenizerRegisterReq,
MultiTokenizerWrapper,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -277,47 +276,7 @@ class Scheduler(
self.model_config = ModelConfig.from_server_args(server_args)
# Init inter-process communication
context = zmq.Context(2)
self.idle_sleeper = None
if self.pp_rank == 0 and self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
else:
# Send to the DetokenizerManager
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
if self.server_args.sleep_on_idle:
self.idle_sleeper = IdleSleeper(
[
self.recv_from_tokenizer,
self.recv_from_rpc,
]
)
else:
self.recv_from_tokenizer = None
self.recv_from_rpc = None
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
if self.current_scheduler_metrics_enabled():
self.send_metrics_from_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.metrics_ipc_name, False
)
self.init_sockets(server_args, port_args)
# Init tokenizer
self.init_tokenizer()
......@@ -578,7 +537,6 @@ class Scheduler(
(ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
(GetLoadReqInput, self.get_load),
]
)
......@@ -634,6 +592,75 @@ class Scheduler(
else:
self.draft_worker = None
def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
context = zmq.Context(2)
self.idle_sleeper = None
class SenderWrapper:
def __init__(self, socket: zmq.Socket):
self.socket = socket
def send_output(
self,
output: Union[BaseReq, BaseBatchReq],
recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None,
):
if self.socket is None:
return
if (
isinstance(recv_obj, BaseReq)
and recv_obj.http_worker_ipc is not None
and output.http_worker_ipc is None
):
# handle communicator reqs for multi-http worker case
output.http_worker_ipc = recv_obj.http_worker_ipc
self.socket.send_pyobj(output)
if self.pp_rank == 0 and self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager
send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
else:
# Send to the DetokenizerManager
send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
self.send_to_tokenizer = SenderWrapper(send_to_tokenizer)
self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)
if self.server_args.sleep_on_idle:
self.idle_sleeper = IdleSleeper(
[
self.recv_from_tokenizer,
self.recv_from_rpc,
]
)
else:
self.recv_from_tokenizer = None
self.recv_from_rpc = None
self.send_to_tokenizer = SenderWrapper(None)
self.send_to_detokenizer = SenderWrapper(None)
if self.current_scheduler_metrics_enabled():
self.send_metrics_from_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.metrics_ipc_name, False
)
def init_deterministic_inference_config(self):
"""Initialize deterministic inference configuration for different attention backends."""
if not self.server_args.enable_deterministic_inference:
......@@ -1107,23 +1134,13 @@ class Scheduler(
self.return_health_check_ct += 1
continue
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
if isinstance(recv_req, MultiTokenizerWrapper):
worker_id = recv_req.worker_id
recv_req = recv_req.obj
output = self._request_dispatcher(recv_req)
if output is not None:
output = MultiTokenizerWrapper(worker_id, output)
self.send_to_tokenizer.send_pyobj(output)
continue
output = self._request_dispatcher(recv_req)
if output is not None:
if isinstance(output, RpcReqOutput):
if self.recv_from_rpc is not None:
self.recv_from_rpc.send_pyobj(output)
else:
self.send_to_tokenizer.send_pyobj(output)
self.send_to_tokenizer.send_output(output, recv_req)
def init_req_max_new_tokens(self, req):
req.sampling_params.max_new_tokens = min(
......@@ -1179,6 +1196,7 @@ class Scheduler(
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
http_worker_ipc=recv_req.http_worker_ipc,
)
req.tokenizer = self.tokenizer
......@@ -1382,7 +1400,7 @@ class Scheduler(
},
rid=req.rid,
)
self.send_to_tokenizer.send_pyobj(abort_req)
self.send_to_tokenizer.send_output(abort_req, req)
def _abort_on_queued_limit(self, recv_req: Req) -> bool:
"""Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
......@@ -1414,7 +1432,7 @@ class Scheduler(
req_to_abort = candidate_req
message = "The request is aborted by a higher priority request."
self.send_to_tokenizer.send_pyobj(
self.send_to_tokenizer.send_output(
AbortReq(
finished_reason={
"type": "abort",
......@@ -1422,7 +1440,8 @@ class Scheduler(
"message": message,
},
rid=req_to_abort.rid,
)
),
req_to_abort,
)
return req_to_abort.rid == recv_req.rid
......@@ -1437,6 +1456,7 @@ class Scheduler(
recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids,
priority=recv_req.priority,
http_worker_ipc=recv_req.http_worker_ipc,
)
req.tokenizer = self.tokenizer
......@@ -1953,8 +1973,8 @@ class Scheduler(
self.num_retracted_reqs = len(retracted_reqs)
self.new_token_ratio = new_token_ratio
for req in reqs_to_abort:
self.send_to_tokenizer.send_pyobj(
AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
self.send_to_tokenizer.send_output(
AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
)
logger.info(
......@@ -2138,7 +2158,7 @@ class Scheduler(
# This is used to prevent the health check signal being blocked by long context prefill.
# However, one minor issue is that this code path does not check the status of detokenizer manager.
self.return_health_check_ct -= 1
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
self.send_to_tokenizer.send_output(HealthCheckOutput())
def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
return self.prepare_mlp_sync_batch_raw(
......@@ -2585,7 +2605,7 @@ class Scheduler(
if self.enable_hicache_storage:
# to release prefetch events associated with the request
self.tree_cache.release_aborted_request(req.rid)
self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
if self.disaggregation_mode == DisaggregationMode.DECODE:
self.tree_cache.cache_finished_req(req)
......@@ -2669,10 +2689,6 @@ class Scheduler(
result = self.tp_worker.unload_lora_adapter(recv_req)
return result
def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
self.send_to_detokenizer.send_pyobj(recv_req)
return recv_req
def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
......@@ -2751,7 +2767,7 @@ class Scheduler(
def handle_freeze_gc(self, recv_req: FreezeGCReq):
"""Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
freeze_gc("Scheduler")
self.send_to_detokenizer.send_pyobj(recv_req)
self.send_to_detokenizer.send_output(recv_req, recv_req)
return None
......
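SenderWrapper plays two roles: it replaces the old SimpleNamespace(send_pyobj=lambda x: None) stub on ranks that do not talk to the tokenizer, and it copies the reply address from the request that triggered a communicator response. A standalone sketch:

```python
from typing import Any, Optional

class SenderWrapper:
    """Sketch of the scheduler-side sender introduced in this diff."""

    def __init__(self, socket):
        # A zmq PUSH socket, or None on ranks that must not send anything.
        self.socket = socket

    def send_output(self, output: Any, recv_obj: Optional[Any] = None):
        if self.socket is None:
            return
        if (
            recv_obj is not None
            and getattr(recv_obj, "http_worker_ipc", None) is not None
            and output.http_worker_ipc is None
        ):
            # Communicator replies are fresh objects; inherit the reply
            # address of the request they answer.
            output.http_worker_ipc = recv_obj.http_worker_ipc
        self.socket.send_pyobj(output)
```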
......@@ -682,6 +682,7 @@ class SchedulerOutputProcessorMixin:
skip_req: Optional[Req] = None,
):
rids = []
http_worker_ipcs = []
finished_reasons: List[BaseFinishReason] = []
decoded_texts = []
......@@ -770,6 +771,7 @@ class SchedulerOutputProcessorMixin:
req.send_output_token_logprobs_offset
)
rids.append(req.rid)
http_worker_ipcs.append(req.http_worker_ipc)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
)
......@@ -886,7 +888,7 @@ class SchedulerOutputProcessorMixin:
if self.model_config.is_multimodal_gen:
return
self.send_to_detokenizer.send_pyobj(
self.send_to_detokenizer.send_output(
BatchTokenIDOutput(
finished_reasons,
decoded_texts,
......@@ -916,6 +918,7 @@ class SchedulerOutputProcessorMixin:
output_token_entropy_val=None,
output_hidden_states=output_hidden_states,
rids=rids,
http_worker_ipcs=http_worker_ipcs,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
)
......@@ -923,6 +926,7 @@ class SchedulerOutputProcessorMixin:
def stream_output_embedding(self: Scheduler, reqs: List[Req]):
rids = []
http_worker_ipcs = []
finished_reasons: List[BaseFinishReason] = []
embeddings = []
......@@ -931,17 +935,19 @@ class SchedulerOutputProcessorMixin:
for req in reqs:
if req.finished():
rids.append(req.rid)
http_worker_ipcs.append(req.http_worker_ipc)
finished_reasons.append(req.finished_reason.to_json())
embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids))
cached_tokens.append(req.cached_tokens)
self.send_to_detokenizer.send_pyobj(
self.send_to_detokenizer.send_output(
BatchEmbeddingOutput(
finished_reasons,
embeddings,
prompt_tokens,
cached_tokens,
rids=rids,
http_worker_ipcs=http_worker_ipcs,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
)
......
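The output processor keeps http_worker_ipcs index-aligned with rids and the other per-request lists, which is the invariant the detokenizer's fan-out depends on; a small sketch of the collection pattern:

```python
from types import SimpleNamespace

reqs = [
    SimpleNamespace(rid=f"r{i}", http_worker_ipc=f"ipc:///tmp/w{i % 2}",
                    finished_reason=None)
    for i in range(3)
]

rids, http_worker_ipcs, finished_reasons = [], [], []
for req in reqs:
    rids.append(req.rid)
    http_worker_ipcs.append(req.http_worker_ipc)
    finished_reasons.append(req.finished_reason)
# Position i in every list describes the same request, so the detokenizer
# can later route entry i to http_worker_ipcs[i].
```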
......@@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import copy
import logging
import os
import time
import uuid
from collections import deque
......@@ -46,7 +45,6 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateOutput,
MultiTokenizerWrapper,
OpenSessionReqInput,
ProfileReq,
ProfileReqOutput,
......@@ -83,8 +81,6 @@ logger = logging.getLogger(__name__)
class _Communicator(Generic[T]):
"""Note: The communicator now only run up to 1 in-flight request at any time."""
enable_multi_tokenizer = False
def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
self._sender = sender
self._fan_out = fan_out
......@@ -104,8 +100,6 @@ class _Communicator(Generic[T]):
assert self._result_values is None
if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj)
self._result_event = asyncio.Event()
......@@ -126,8 +120,6 @@ class _Communicator(Generic[T]):
self._result_event = asyncio.Event()
if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj)
await self._result_event.wait()
......@@ -617,8 +609,6 @@ class TokenizerCommunicatorMixin:
elif obj.session_id in self.session_futures:
return None
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[obj.session_id] = asyncio.Future()
......
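With the wrapper and the enable_multi_tokenizer class flag gone, _Communicator sends request objects untouched; a condensed sketch of its send/await cycle (the real class also supports a queueing mode):

```python
import asyncio
from typing import Any, List, Optional

class Communicator:
    """Sketch: send one request object, await fan_out result objects."""

    def __init__(self, sender, fan_out: int = 1):
        self._sender = sender              # anything exposing send_pyobj(obj)
        self._fan_out = fan_out
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: List[Any] = []

    async def __call__(self, obj):
        if obj is not None:
            self._sender.send_pyobj(obj)   # sent as-is: no wrapper, no pid
        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()
        return self._result_values

    def handle_recv(self, recv_obj):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()
```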
......@@ -46,6 +46,7 @@ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchT
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BaseReq,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchStrOutput,
......@@ -58,7 +59,6 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput,
GetLoadReqInput,
HealthCheckOutput,
MultiTokenizerWrapper,
OpenSessionReqOutput,
SessionParams,
TokenizedEmbeddingReqInput,
......@@ -88,7 +88,6 @@ from sglang.srt.utils import (
dataclass_to_string_truncated,
freeze_gc,
get_bool_env_var,
get_origin_rid,
get_zmq_socket,
kill_process_tree,
)
......@@ -258,9 +257,18 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
if self.server_args.tokenizer_worker_num > 1:
# Use tokenizer_worker_ipc_name in multi-tokenizer mode
self.send_to_scheduler = get_zmq_socket(
send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
)
class SenderWrapper:
def send_pyobj(self, obj):
if isinstance(obj, BaseReq):
obj.http_worker_ipc = port_args.tokenizer_ipc_name
send_to_scheduler.send_pyobj(obj)
# Make sure that each request carries the tokenizer_ipc_name for response routing
self.send_to_scheduler = SenderWrapper()
else:
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
......@@ -376,13 +384,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
obj.normalize_batch_and_arguments()
if self.server_args.tokenizer_worker_num > 1:
# Modify rid, add worker_id
if isinstance(obj.rid, list):
# If it's an array, add worker_id prefix to each element
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
else:
# If it's a single value, add worker_id prefix
obj.rid = f"{self.worker_id}_{obj.rid}"
from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker
assert isinstance(self, TokenizerWorker)
self._attach_multi_http_worker_info(obj)
if self.enable_trace:
self._trace_request_start(obj, created_time)
......@@ -728,6 +733,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
obj.token_ids_logprob,
obj.stream,
rid=obj.rid,
http_worker_ipc=obj.http_worker_ipc,
bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port,
bootstrap_room=obj.bootstrap_room,
......@@ -749,6 +755,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
sampling_params,
rid=obj.rid,
priority=obj.priority,
http_worker_ipc=obj.http_worker_ipc,
)
return tokenized_obj
......@@ -1109,8 +1116,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
) -> Tuple[bool, str]:
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
if self.server_args.dp_size == 1:
......@@ -1349,12 +1354,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
continue
origin_rid = rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(rid)
# Build meta_info and return value
meta_info = {
"id": origin_rid,
"id": rid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
"weight_version": self.server_args.weight_version,
......@@ -1708,9 +1710,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if is_health_check_generate_req(recv_obj):
return
state = self.rid_to_state[recv_obj.rid]
origin_rid = recv_obj.rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(origin_rid)
state.finished = True
if recv_obj.finished_reason:
out = {
......@@ -1723,7 +1722,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
out = {
"text": "",
"meta_info": {
"id": origin_rid,
"id": recv_obj.rid,
"finish_reason": {
"type": "abort",
"message": "Abort before prefill",
......
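The worker-side counterpart is the small send_to_scheduler wrapper above: every BaseReq leaving a multi-tokenizer worker is stamped with that worker's own IPC name before it hits the wire. Extracted as a standalone sketch:

```python
class StampingSender:
    """Sketch of the multi-tokenizer send_to_scheduler wrapper."""

    def __init__(self, socket, tokenizer_ipc_name: str):
        self._socket = socket              # zmq PUSH socket to the router
        self._ipc_name = tokenizer_ipc_name

    def send_pyobj(self, obj):
        # In the real code this checks isinstance(obj, BaseReq); batch
        # requests are stamped later via _attach_multi_http_worker_info.
        if hasattr(obj, "http_worker_ipc"):
            obj.http_worker_ipc = self._ipc_name
        self._socket.send_pyobj(obj)
```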
......@@ -3006,10 +3006,6 @@ def lru_cache_frozenset(maxsize=128):
return decorator
def get_origin_rid(rid):
return rid.split("_", 1)[1] if "_" in rid else rid
def apply_module_patch(target_module, target_function, wrappers):
original_module, original_function = parse_module_path(
target_module, target_function, False
......