Unverified Commit 5f77e129 authored by ybyang's avatar ybyang Committed by GitHub
Browse files

Support Multi Process Tokenizer Manager(#6555) (#8964)


Signed-off-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
Co-authored-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
parent 4750cddf
......@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
......@@ -814,18 +815,24 @@ def _launch_subprocesses(
),
)
detoken_proc.start()
if server_args.tokenizer_worker_num > 1:
# Launch multi-tokenizer router
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
# Initialize templates
template_manager = None
else:
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
)
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
)
# Wait for the model to finish loading
scheduler_infos = []
......
......@@ -23,6 +23,7 @@ import json
import logging
import multiprocessing as multiprocessing
import os
import tempfile
import threading
import time
from http import HTTPStatus
......@@ -91,11 +92,18 @@ from sglang.srt.managers.io_struct import (
UpdateWeightVersionReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.multi_tokenizer_mixin import (
MultiTokenizerManager,
deserialize_data,
get_main_process_id,
read_from_shared_memory,
write_data_for_multi_tokenizer,
)
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
add_api_key_middleware,
add_prometheus_middleware,
......@@ -130,8 +138,79 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state
# Function to set up all middlewares for multi-tokenizer compatibility
def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
    """Setup all middlewares for both single and multi-process modes"""
    # Runs once per uvicorn worker process, hence the per-PID log lines below.
    worker_pid = os.getpid()
    if api_key:
        add_api_key_middleware(app, api_key)
        logger.info(f"Worker {worker_pid} added API key middleware")
    if enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()
        logger.info(f"Worker {worker_pid} added prometheus middleware")
async def init_multi_tokenizer() -> ServerArgs:
    """Read args information from shm and init tokenizer manager for current process.

    Runs inside each uvicorn worker process in multi-tokenizer mode. The main
    process has previously written PortArgs/ServerArgs/scheduler_info into
    shared-memory segments named after its own PID (see
    write_data_for_multi_tokenizer), which is this worker's parent PID.

    Returns:
        The deserialized ServerArgs for this worker.
    """
    pid = os.getpid()
    main_pid = get_main_process_id()
    logger.info(f"current worker_id: {pid}, main processID: {main_pid}")

    # Read configuration from shared memory, keyed by the launcher's PID.
    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
    port_args, server_args = deserialize_data(port_args_data, server_args_data)
    scheduler_info = scheduler_info_data

    # Each worker needs its own tokenizer IPC endpoint; NamedTemporaryFile is
    # only used to mint a unique filesystem path for the ZMQ ipc:// socket.
    port_args.tokenizer_ipc_name = (
        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
    )

    # Launch multi-tokenizer manager for this worker process
    tokenizer_manager = MultiTokenizerManager(server_args, port_args)
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )
    # Register this tokenizer with the main tokenizer manager
    await tokenizer_manager.register_to_main_tokenizer_manager()

    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
    set_global_state(
        _GlobalState(
            tokenizer_manager=tokenizer_manager,
            template_manager=template_manager,
            scheduler_info=scheduler_info,
        )
    )
    return server_args
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args = getattr(fast_api_app, "server_args", None)
if server_args is None:
# Initialize multi-tokenizer support for worker processes
fast_api_app.server_args = await init_multi_tokenizer()
setup_middlewares(
fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
)
fast_api_app.warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
fast_api_app.server_args,
None, # pipe_finish_writer not needed in worker
None, # launch_callback not needed in worker
),
)
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
......@@ -191,7 +270,15 @@ async def lifespan(fast_api_app: FastAPI):
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
yield
try:
yield
finally:
if server_args.tokenizer_worker_num > 1:
pid = os.getpid()
logger.info(f"uvicorn worker {pid} ending...")
warmup_thread.join()
logger.info(f"uvicorn worker {pid} ended.")
# Fast API
......@@ -1078,9 +1165,19 @@ def launch_server(
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args
)
if server_args.tokenizer_worker_num > 1:
port_args = PortArgs.init_new(server_args)
port_args.tokenizer_worker_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
)
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args, port_args=port_args
)
else:
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args,
)
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
......@@ -1089,42 +1186,75 @@ def launch_server(
)
)
# Add api key authorization
if server_args.api_key:
add_api_key_middleware(app, server_args.api_key)
# Add prometheus middleware
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Send a warmup request - we will create the thread launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
pipe_finish_writer,
launch_callback,
),
)
app.warmup_thread = warmup_thread
if server_args.tokenizer_worker_num > 1:
port_args_shm, server_args_shm, scheduler_info_shm = (
write_data_for_multi_tokenizer(
port_args,
server_args,
scheduler_info,
)
)
else:
# Add api key authorization
if server_args.api_key:
add_api_key_middleware(app, server_args.api_key)
# Add prometheus middleware
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Send a warmup request - we will create the thread launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
pipe_finish_writer,
launch_callback,
),
)
app.warmup_thread = warmup_thread
try:
# Update logging configs
set_uvicorn_logging_configs()
app.server_args = server_args
# Listen for HTTP requests
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level_http or server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
)
if server_args.tokenizer_worker_num > 1:
from uvicorn.config import LOGGING_CONFIG
LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
"handlers": ["default"],
"level": "INFO",
"propagate": False,
}
uvicorn.run(
"sglang.srt.entrypoints.http_server:app",
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level_http or server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
workers=server_args.tokenizer_worker_num,
)
else:
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level_http or server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
)
finally:
warmup_thread.join()
if server_args.tokenizer_worker_num > 1:
port_args_shm.unlink()
server_args_shm.unlink()
scheduler_info_shm.unlink()
_global_state.tokenizer_manager.clear_tokenizer_mapping()
else:
warmup_thread.join()
def _execute_server_warmup(
......
......@@ -32,11 +32,14 @@ from sglang.srt.managers.io_struct import (
BatchStrOut,
BatchTokenIDOut,
FreezeGCReq,
MultiTokenizerRegisterReq,
)
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
freeze_gc,
get_worker_ids_from_req_rids,
get_zmq_socket,
kill_itself_when_parent_died,
)
......@@ -67,7 +70,7 @@ class DecodeStatus:
sent_offset: int = 0
class DetokenizerManager:
class DetokenizerManager(MultiTokenizerMixin):
"""DetokenizerManager is a process that detokenizes the token ids."""
def __init__(
......@@ -102,6 +105,7 @@ class DetokenizerManager:
(BatchEmbeddingOut, self.handle_batch_embedding_out),
(BatchTokenIDOut, self.handle_batch_token_id_out),
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
(MultiTokenizerRegisterReq, lambda x: x),
(FreezeGCReq, self.handle_freeze_gc_req),
]
)
......@@ -116,6 +120,39 @@ class DetokenizerManager:
if output is not None:
self.send_to_tokenizer.send_pyobj(output)
def multi_tokenizer_manager_event_loop(self):
    """The event loop that handles requests, for multi tokenizer manager mode only.

    Unlike the single-tokenizer event_loop, results are not pushed to one
    tokenizer socket: the worker id is parsed out of each rid and the sliced
    per-request output is pushed to that worker's registered PUSH socket.
    """
    self.create_sockets_mapping()
    while True:
        recv_obj = self.recv_from_scheduler.recv_pyobj()
        output = self._request_dispatcher(recv_obj)
        if output is None:
            continue
        # Extract worker_id from rid
        if isinstance(recv_obj.rids, list):
            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
        else:
            raise RuntimeError(
                f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
            )

        # Send data using the corresponding socket
        for i, worker_id in enumerate(worker_ids):
            if isinstance(recv_obj, MultiTokenizerRegisterReq):
                # Registration request: open (at most once) a PUSH socket
                # back to this worker; the register call acks by echoing
                # the request over the new socket.
                if self.register_tokenizer_ipc(recv_obj, worker_id):
                    logger.info(
                        f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
                    )
                continue
            else:
                if worker_id not in self.tokenizer_mapping:
                    logger.error(
                        f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
                    )
                    continue
                # Slice the i-th entry out of the batched output and route it
                # to the worker that owns this rid.
                new_output = self._handle_output_by_index(output, i)
                self.tokenizer_mapping[worker_id].send_pyobj(new_output)
def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
......@@ -285,8 +322,12 @@ def run_detokenizer_process(
try:
manager = DetokenizerManager(server_args, port_args)
manager.event_loop()
if server_args.tokenizer_worker_num > 1:
manager.multi_tokenizer_manager_event_loop()
else:
manager.event_loop()
except Exception:
manager.clear_tokenizer_mapping()
traceback = get_exception_traceback()
logger.error(f"DetokenizerManager hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
......@@ -983,6 +983,11 @@ class AbortReq:
abort_all: bool = False
# The finished reason data
finished_reason: Optional[Dict[str, Any]] = None
# used in MultiTokenizerManager mode
rids: Optional[Union[List[str], str]] = None
def __post_init__(self):
    # Mirror `rid` into `rids`: in multi-tokenizer mode results are routed by
    # parsing the worker id out of `rids`, so abort requests must carry it too.
    self.rids = self.rid
@dataclass
......@@ -1183,6 +1188,18 @@ class LoRAUpdateResult:
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
@dataclass
class MultiTokenizerRegisterReq:
    """Registration request a tokenizer worker sends so the router and
    detokenizer can open a ZMQ PUSH socket back to it."""

    # Routing ids, e.g. [f"{worker_id}_register"], so the worker id can be
    # parsed out of the rid prefix.
    rids: Optional[Union[List[str], str]] = None
    # ZMQ IPC address this worker listens on for results.
    ipc_name: Optional[str] = None
@dataclass
class MultiTokenizerWarpper:
    """Envelope tagging a request/response with the originating tokenizer
    worker id so the scheduler/router can route the reply back directly.

    NOTE(review): "Warpper" looks like a typo for "Wrapper", but renaming
    would break every importer and pickled message, so it is kept as-is.
    """

    # PID of the tokenizer worker that issued `obj`.
    worker_id: int
    # The wrapped request/response payload.
    obj: Optional[Any] = None
class BlockReqType(Enum):
    """Kind of a block/unblock control request."""

    BLOCK = 1
    UNBLOCK = 2
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
import asyncio
import dataclasses
import json
import logging
import multiprocessing as multiprocessing
import os
import sys
import threading
from multiprocessing import shared_memory
from typing import Dict
import zmq
import zmq.asyncio
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
MultiTokenizerRegisterReq,
MultiTokenizerWarpper,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_worker_ids_from_req_rids,
get_zmq_socket,
kill_process_tree,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
class MultiTokenizerMixin:
    """Mixin class for MultiTokenizerManager and DetokenizerManager.

    Provides the worker registry (worker_id -> ZMQ PUSH socket) and the
    batched-output slicing used to route per-request results back to the
    uvicorn worker process that owns each rid.
    """

    def create_sockets_mapping(self):
        """Lazily create the worker registry and its ZMQ context."""
        if not hasattr(self, "tokenizer_mapping"):
            self.tokenizer_mapping = {}
        # Create ZMQ context if needed
        if not hasattr(self, "_zmq_context"):
            self._zmq_context = zmq.Context()

    def init_tokenizer_mapping(
        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
    ):
        """init tokenizer mapping from register request.

        Returns True when a new socket was created (first registration),
        False when the worker was already registered.
        """
        ipc_name = recv_obj.ipc_name
        # Registry keys are stored as ints even if callers pass strings.
        worker_id_int = int(worker_id)

        if worker_id_int not in self.tokenizer_mapping:
            socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
            self.tokenizer_mapping[worker_id_int] = socket
            # Echo the register request back over the new socket as an ack.
            self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
            return True
        else:
            return False

    def register_tokenizer_ipc(self, recv_obj, worker_id):
        """Register `worker_id` on first sight; True iff a socket was created.

        NOTE(review): init_tokenizer_mapping keys the registry by
        int(worker_id) while this membership test uses worker_id as-is; this
        is only consistent if callers always pass ints — confirm that
        get_worker_ids_from_req_rids returns ints.
        """
        if worker_id not in self.tokenizer_mapping:
            # register the worker if not already done
            if isinstance(recv_obj, MultiTokenizerRegisterReq):
                return self.init_tokenizer_mapping(recv_obj, worker_id)
            else:
                logger.error(
                    f"Worker {worker_id} not registered and not found in tokenizer mapping . "
                    "Please ensure the worker is registered correctly."
                )
        return False

    def _handle_output_by_index(self, output, i):
        """Project the i-th entry of a batched output into a single-element batch.

        Fields guarded by ``len(...) > i`` fall back to None when shorter than
        the batch; optional logprob fields are sliced whenever present
        (assumed index-aligned with ``rids`` — NOTE(review): confirm).

        NOTE: A maintainable method is better here.
        """
        if isinstance(output, BatchTokenIDOut):
            new_output = BatchTokenIDOut(
                rids=[output.rids[i]],
                finished_reasons=(
                    [output.finished_reasons[i]]
                    if len(output.finished_reasons) > i
                    else None
                ),
                decoded_texts=(
                    [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
                ),
                decode_ids=(
                    [output.decode_ids[i]] if len(output.decode_ids) > i else None
                ),
                read_offsets=(
                    [output.read_offsets[i]] if len(output.read_offsets) > i else None
                ),
                output_ids=(
                    [output.output_ids[i]]
                    if output.output_ids and len(output.output_ids) > i
                    else None
                ),
                skip_special_tokens=(
                    [output.skip_special_tokens[i]]
                    if len(output.skip_special_tokens) > i
                    else None
                ),
                spaces_between_special_tokens=(
                    [output.spaces_between_special_tokens[i]]
                    if len(output.spaces_between_special_tokens) > i
                    else None
                ),
                no_stop_trim=(
                    [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
                ),
                prompt_tokens=(
                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
                ),
                completion_tokens=(
                    [output.completion_tokens[i]]
                    if len(output.completion_tokens) > i
                    else None
                ),
                cached_tokens=(
                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
                ),
                spec_verify_ct=(
                    [output.spec_verify_ct[i]]
                    if len(output.spec_verify_ct) > i
                    else None
                ),
                input_token_logprobs_val=(
                    [output.input_token_logprobs_val[i]]
                    if output.input_token_logprobs_val
                    else None
                ),
                input_token_logprobs_idx=(
                    [output.input_token_logprobs_idx[i]]
                    if output.input_token_logprobs_idx
                    else None
                ),
                output_token_logprobs_val=(
                    [output.output_token_logprobs_val[i]]
                    if output.output_token_logprobs_val
                    else None
                ),
                output_token_logprobs_idx=(
                    [output.output_token_logprobs_idx[i]]
                    if output.output_token_logprobs_idx
                    else None
                ),
                input_top_logprobs_val=(
                    [output.input_top_logprobs_val[i]]
                    if output.input_top_logprobs_val
                    else None
                ),
                input_top_logprobs_idx=(
                    [output.input_top_logprobs_idx[i]]
                    if output.input_top_logprobs_idx
                    else None
                ),
                output_top_logprobs_val=(
                    [output.output_top_logprobs_val[i]]
                    if output.output_top_logprobs_val
                    else None
                ),
                output_top_logprobs_idx=(
                    [output.output_top_logprobs_idx[i]]
                    if output.output_top_logprobs_idx
                    else None
                ),
                input_token_ids_logprobs_val=(
                    [output.input_token_ids_logprobs_val[i]]
                    if output.input_token_ids_logprobs_val
                    else None
                ),
                input_token_ids_logprobs_idx=(
                    [output.input_token_ids_logprobs_idx[i]]
                    if output.input_token_ids_logprobs_idx
                    else None
                ),
                output_token_ids_logprobs_val=(
                    [output.output_token_ids_logprobs_val[i]]
                    if output.output_token_ids_logprobs_val
                    else None
                ),
                output_token_ids_logprobs_idx=(
                    [output.output_token_ids_logprobs_idx[i]]
                    if output.output_token_ids_logprobs_idx
                    else None
                ),
                output_hidden_states=(
                    [output.output_hidden_states[i]]
                    if output.output_hidden_states
                    else None
                ),
            )
        elif isinstance(output, BatchEmbeddingOut):
            new_output = BatchEmbeddingOut(
                rids=[output.rids[i]],
                finished_reasons=(
                    [output.finished_reasons[i]]
                    if len(output.finished_reasons) > i
                    else None
                ),
                embeddings=(
                    [output.embeddings[i]] if len(output.embeddings) > i else None
                ),
                prompt_tokens=(
                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
                ),
                cached_tokens=(
                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
                ),
            )
        elif isinstance(output, BatchStrOut):
            new_output = BatchStrOut(
                rids=[output.rids[i]],
                finished_reasons=(
                    [output.finished_reasons[i]]
                    if len(output.finished_reasons) > i
                    else None
                ),
                output_strs=(
                    [output.output_strs[i]] if len(output.output_strs) > i else None
                ),
                output_ids=(
                    [output.output_ids[i]]
                    if output.output_ids and len(output.output_ids) > i
                    else None
                ),
                prompt_tokens=(
                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
                ),
                completion_tokens=(
                    [output.completion_tokens[i]]
                    if len(output.completion_tokens) > i
                    else None
                ),
                cached_tokens=(
                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
                ),
                spec_verify_ct=(
                    [output.spec_verify_ct[i]]
                    if len(output.spec_verify_ct) > i
                    else None
                ),
                input_token_logprobs_val=(
                    [output.input_token_logprobs_val[i]]
                    if output.input_token_logprobs_val
                    else None
                ),
                input_token_logprobs_idx=(
                    [output.input_token_logprobs_idx[i]]
                    if output.input_token_logprobs_idx
                    else None
                ),
                output_token_logprobs_val=(
                    [output.output_token_logprobs_val[i]]
                    if output.output_token_logprobs_val
                    else None
                ),
                output_token_logprobs_idx=(
                    [output.output_token_logprobs_idx[i]]
                    if output.output_token_logprobs_idx
                    else None
                ),
                input_top_logprobs_val=(
                    [output.input_top_logprobs_val[i]]
                    if output.input_top_logprobs_val
                    else None
                ),
                input_top_logprobs_idx=(
                    [output.input_top_logprobs_idx[i]]
                    if output.input_top_logprobs_idx
                    else None
                ),
                output_top_logprobs_val=(
                    [output.output_top_logprobs_val[i]]
                    if output.output_top_logprobs_val
                    else None
                ),
                output_top_logprobs_idx=(
                    [output.output_top_logprobs_idx[i]]
                    if output.output_top_logprobs_idx
                    else None
                ),
                input_token_ids_logprobs_val=(
                    [output.input_token_ids_logprobs_val[i]]
                    if output.input_token_ids_logprobs_val
                    else None
                ),
                input_token_ids_logprobs_idx=(
                    [output.input_token_ids_logprobs_idx[i]]
                    if output.input_token_ids_logprobs_idx
                    else None
                ),
                output_token_ids_logprobs_val=(
                    [output.output_token_ids_logprobs_val[i]]
                    if output.output_token_ids_logprobs_val
                    else None
                ),
                output_token_ids_logprobs_idx=(
                    [output.output_token_ids_logprobs_idx[i]]
                    if output.output_token_ids_logprobs_idx
                    else None
                ),
                output_hidden_states=(
                    [output.output_hidden_states[i]]
                    if output.output_hidden_states
                    else None
                ),
            )
        elif isinstance(output, BatchMultimodalOut):
            new_output = BatchMultimodalOut(
                rids=[output.rids[i]],
                finished_reasons=(
                    [output.finished_reasons[i]]
                    if len(output.finished_reasons) > i
                    else None
                ),
                outputs=([output.outputs[i]] if len(output.outputs) > i else None),
                prompt_tokens=(
                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
                ),
                completion_tokens=(
                    [output.completion_tokens[i]]
                    if len(output.completion_tokens) > i
                    else None
                ),
                cached_tokens=(
                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
                ),
            )
        else:
            # Unknown types (e.g. already-unwrapped control messages) pass through.
            new_output = output
        return new_output

    def clear_tokenizer_mapping(self):
        """Close and drop every registered worker socket (best-effort cleanup)."""
        if hasattr(self, "tokenizer_mapping"):
            for socket in self.tokenizer_mapping.values():
                try:
                    socket.close()
                except Exception as e:
                    logger.warning(f"Failed to close socket: {e}")
            self.tokenizer_mapping.clear()
class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
    """A router to receive requests from MultiTokenizerManager.

    Bridges many per-worker MultiTokenizerManager processes to the single
    scheduler: requests from workers are forwarded to the scheduler, and
    results coming back from the detokenizer are routed to the owning worker
    by the worker-id prefix embedded in each rid.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # NOTE(review): TokenizerManager.__init__ is deliberately not called;
        # the router only sets up the ZMQ plumbing below plus
        # init_disaggregation() (inherited from TokenizerManager).
        self.server_args = server_args
        context = zmq.asyncio.Context(3)  # asyncio ZMQ context with 3 IO threads
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )
        self.receive_from_worker = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
        )
        # Run the router's coroutines on a dedicated background event loop
        # so construction does not block the caller.
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        self._task = asyncio.run_coroutine_threadsafe(
            self.router_worker_obj(), self._loop
        )
        # Start handle_loop simultaneously
        self._handle_task = asyncio.run_coroutine_threadsafe(
            print_exception_wrapper(self.handle_loop), self._loop
        )
        self.init_disaggregation()

    def _run_loop(self):
        """Blocking runner for the router's private event loop thread."""
        self._loop.run_forever()

    async def router_worker_obj(self):
        """Forward every request arriving from tokenizer workers to the scheduler."""
        while True:
            recv_obj = await self.receive_from_worker.recv_pyobj()
            await self.send_to_scheduler.send_pyobj(recv_obj)

    async def handle_loop(self):
        # special reqs will recv from scheduler, need to route to right worker
        self.create_sockets_mapping()
        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            await self._distribute_result_to_workers(recv_obj)

    async def _distribute_result_to_workers(self, recv_obj):
        """Distribute result to corresponding workers based on rid"""
        if isinstance(recv_obj, MultiTokenizerWarpper):
            # Already tagged with its worker id; unwrap the payload.
            worker_ids = [recv_obj.worker_id]
            recv_obj = recv_obj.obj
        else:
            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)

        if len(worker_ids) == 0:
            logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
            return

        # Distribute result to each worker
        for i, worker_id in enumerate(worker_ids):
            if isinstance(recv_obj, MultiTokenizerRegisterReq):
                if self.register_tokenizer_ipc(recv_obj, worker_id):
                    logger.info(
                        f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
                    )
                continue
            else:
                if worker_id not in self.tokenizer_mapping:
                    logger.error(
                        f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
                    )
                    continue
                # Slice the i-th batched entry and push it to the owning worker.
                new_recv_obj = self._handle_output_by_index(recv_obj, i)
                self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
    """Multi Process Tokenizer Manager that tokenizes the text.

    One instance lives in each uvicorn worker process; it registers itself
    with the MultiTokenizerRouter so results can be routed back to it.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # prevent init prefill bootstrapserver again: TokenizerManager.__init__
        # calls init_disaggregation(), which starts the KV bootstrap server on
        # prefill; only one process should own it, so disaggregation is masked
        # during base-class init and restored afterwards.
        disaggregation_mode = server_args.disaggregation_mode
        server_args.disaggregation_mode = "null"
        super().__init__(server_args, port_args)

        # This process's PID doubles as the worker id prefixed onto rids.
        self.worker_id = os.getpid()
        self.tokenizer_ipc_name = port_args.tokenizer_ipc_name

        # For PD disaggregation: restore the real mode without re-running
        # the bootstrap-server setup.
        self.server_args.disaggregation_mode = disaggregation_mode
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.disaggregation_transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # Communicator used for the registration handshake (expects 2 replies:
        # router ack and detokenizer ack — NOTE(review): confirm).
        self.register_multi_tokenizer_communicator = _Communicator(
            self.send_to_scheduler, 2
        )
        self._result_dispatcher._mapping.append(
            (
                MultiTokenizerRegisterReq,
                self.register_multi_tokenizer_communicator.handle_recv,
            )
        )

    async def register_to_main_tokenizer_manager(self):
        """Register this worker to the main TokenizerManager"""
        # create a handle loop to receive messages from the main TokenizerManager
        self.auto_create_handle_loop()
        req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
        req.ipc_name = self.tokenizer_ipc_name
        _Communicator.enable_multi_tokenizer = True
        await self.register_multi_tokenizer_communicator(req)
async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print exception.
    We do another wrapper to handle the exception: log the traceback,
    dump in-flight requests when the owner is a MultiTokenizerRouter,
    then hard-exit the whole process tree.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"MultiTokenizerRouter hit an exception: {traceback}")
        owner = getattr(func, "__self__", None)
        if isinstance(owner, MultiTokenizerRouter):
            owner.dump_requests_before_crash()
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
def serialize_port_args(port_args: PortArgs) -> dict:
    """Serialize PortArgs into a shareable dictionary"""
    field_names = (
        "tokenizer_ipc_name",
        "scheduler_input_ipc_name",
        "detokenizer_ipc_name",
        "nccl_port",
        "rpc_ipc_name",
        "metrics_ipc_name",
        "tokenizer_worker_ipc_name",
    )
    return {name: getattr(port_args, name) for name in field_names}
def deserialize_data(port_args: dict, server_args: dict):
    """Deserialize data from shared dictionaries.

    Rebuilds the (PortArgs, ServerArgs) pair from the plain dicts that were
    written to shared memory.
    """
    rebuilt_port_args = PortArgs(**port_args)
    rebuilt_server_args = ServerArgs(**server_args)
    return rebuilt_port_args, rebuilt_server_args
def serialize_server_args(server_args: ServerArgs) -> dict:
    """Serialize ServerArgs into a shareable dictionary.

    ServerArgs is a dataclass, so asdict() produces its plain-dict form.
    """
    as_plain_dict = dataclasses.asdict(server_args)
    return as_plain_dict
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
    """Serialize scheduler_info into a shareable dictionary"""
    # scheduler_info is already a plain JSON-serializable dict; pass through.
    return scheduler_info
def deserialize_scheduler_info(data: dict) -> Dict:
    """Deserialize scheduler_info from a shared dictionary"""
    # Inverse of serialize_scheduler_info; identity pass-through.
    return data
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
    """Write `data` as UTF-8 JSON into the shared-memory segment `name`.

    Reuses an existing segment when it is large enough, otherwise recreates
    it. Returns the (still-attached) SharedMemory object; callers own
    close()/unlink().
    """
    serialized = json.dumps(data).encode("utf-8")
    size = len(serialized)
    try:
        # Try to open existing shared memory
        shm = shared_memory.SharedMemory(name=name)
        # If size is insufficient, close and recreate
        if shm.size < size:
            shm.close()
            shm.unlink()
            shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    except FileNotFoundError:
        # If not present, create new shared memory
        shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    shm.buf[:size] = serialized
    # BUG FIX: zero the remainder of the buffer. The segment can be larger
    # than the payload (reuse of a bigger segment, or page-size rounding of
    # the allocation), and readers that parse the whole buffer would
    # otherwise see stale bytes from a previous, longer payload.
    if shm.size > size:
        shm.buf[size:] = b"\x00" * (shm.size - size)
    return shm
def read_from_shared_memory(name: str) -> dict:
    """Read a UTF-8 JSON dict from the shared-memory segment `name`.

    Raises:
        FileNotFoundError: if the segment does not exist.
    """
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        raise FileNotFoundError(f"Shared memory {name} not found")
    try:
        # BUG FIX: the segment may be larger than the JSON payload (page-size
        # rounding of the allocation, or reuse of a bigger segment); strip the
        # trailing NUL padding, otherwise json.loads rejects the extra bytes.
        payload = bytes(shm.buf).rstrip(b"\x00").decode("utf-8")
        return json.loads(payload)
    finally:
        # Always detach, even when decoding fails (previously leaked on error).
        shm.close()
def get_main_process_id() -> int:
    """Get the main process ID"""
    # NOTE(review): relies on the private `_parent_pid` attribute; for a
    # child process this is the launcher's PID, but in the top-level
    # MainProcess it is presumably None — confirm callers tolerate that
    # (write_data_for_multi_tokenizer only logs the value).
    return multiprocessing.current_process()._parent_pid
def write_data_for_multi_tokenizer(
    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
    """Write args information to share memory for multi-tokenizer.

    Called once in the launcher process before uvicorn forks its workers.
    Segments are named after *this* process's PID; workers later read them
    via their parent PID (see init_multi_tokenizer), which matches.

    Returns the three SharedMemory handles so the caller can unlink() them
    on shutdown.
    """
    # get main process ID
    main_pid = get_main_process_id()
    current_pid = os.getpid()
    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

    # Write port_args to shared memory
    port_args_shm = write_to_shared_memory(
        serialize_port_args(port_args), f"port_args_{current_pid}"
    )
    # Write server_args to shared memory
    server_args_shm = write_to_shared_memory(
        serialize_server_args(server_args), f"server_args_{current_pid}"
    )
    # Write scheduler_info to shared memory
    scheduler_info_shm = write_to_shared_memory(
        serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
    )

    # Detach our local mappings; the OS keeps the segments alive for the
    # workers until unlink() is called on the returned handles.
    port_args_shm.close()
    server_args_shm.close()
    scheduler_info_shm.close()

    return port_args_shm, server_args_shm, scheduler_info_shm
......@@ -84,6 +84,8 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
MultiTokenizerRegisterReq,
MultiTokenizerWarpper,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -257,7 +259,6 @@ class Scheduler(
# Init inter-process communication
context = zmq.Context(2)
self.idle_sleeper = None
if self.pp_rank == 0 and self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
......@@ -540,6 +541,7 @@ class Scheduler(
(ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
]
)
......@@ -1101,6 +1103,17 @@ class Scheduler(
)
self.send_to_tokenizer.send_pyobj(abort_req)
continue
# If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
if isinstance(recv_req, MultiTokenizerWarpper):
worker_id = recv_req.worker_id
recv_req = recv_req.obj
output = self._request_dispatcher(recv_req)
if output is not None:
output = MultiTokenizerWarpper(worker_id, output)
self.send_to_tokenizer.send_pyobj(output)
continue
output = self._request_dispatcher(recv_req)
if output is not None:
if isinstance(output, RpcReqOutput):
......@@ -2474,6 +2487,10 @@ class Scheduler(
result = self.tp_worker.unload_lora_adapter(recv_req)
return result
def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
    """Relay a tokenizer-worker registration request to the detokenizer.

    Forwarding lets the DetokenizerManager open its own socket back to the
    registering worker; returning the request makes the dispatcher send it
    back upstream as an ack as well.
    """
    self.send_to_detokenizer.send_pyobj(recv_req)
    return recv_req
def slow_down(self, recv_req: SlowDownReqInput):
t = recv_req.forward_sleep_time
if t is not None and t <= 0:
......
......@@ -94,6 +94,7 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
MultiTokenizerWarpper,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -131,6 +132,7 @@ from sglang.srt.utils import (
dataclass_to_string_truncated,
freeze_gc,
get_bool_env_var,
get_origin_rid,
get_zmq_socket,
kill_process_tree,
)
......@@ -266,9 +268,15 @@ class TokenizerManager:
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
if self.server_args.tokenizer_worker_num > 1:
# Use tokenizer_worker_ipc_name in multi-tokenizer mode
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
)
else:
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
# Request states
self.no_create_loop = False
......@@ -312,35 +320,7 @@ class TokenizerManager:
self.lora_update_lock = asyncio.Lock()
# For PD disaggregtion
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.disaggregation_transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv boostrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class(
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)
is_create_store = (
self.server_args.node_rank == 0
and self.server_args.disaggregation_transfer_backend == "ascend"
)
if is_create_store:
try:
from mf_adapter import create_config_store
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
self.init_disaggregation()
# For load balancing
self.current_load = 0
......@@ -488,6 +468,37 @@ class TokenizerManager:
]
)
def init_disaggregation(self):
    """Initialize prefill/decode (PD) disaggregation state from server_args.

    Parses the disaggregation mode and transfer backend. On a PREFILL
    instance it additionally starts the KV bootstrap server and, when the
    "ascend" backend is selected on node rank 0, creates the mf-adapter
    config store pointed at by the ASCEND_MF_STORE_URL environment variable.

    Raises:
        RuntimeError: if the ascend config store cannot be created.
    """
    self.disaggregation_mode = DisaggregationMode(
        self.server_args.disaggregation_mode
    )
    self.disaggregation_transfer_backend = TransferBackend(
        self.server_args.disaggregation_transfer_backend
    )
    # Start kv bootstrap server on prefill
    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        # only start bootstrap server on prefill tm
        kv_bootstrap_server_class = get_kv_class(
            self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
        )
        self.bootstrap_server = kv_bootstrap_server_class(
            self.server_args.disaggregation_bootstrap_port
        )
        is_create_store = (
            self.server_args.node_rank == 0
            and self.server_args.disaggregation_transfer_backend == "ascend"
        )
        if is_create_store:
            try:
                from mf_adapter import create_config_store

                ascend_url = os.getenv("ASCEND_MF_STORE_URL")
                create_config_store(ascend_url)
            except Exception as e:
                # Bug fix: the original did `raise error_message` with a str,
                # which itself fails with "TypeError: exceptions must derive
                # from BaseException". Raise a real exception and chain the
                # original cause instead.
                error_message = "Failed create mf store, invalid ascend_url."
                error_message += f" With exception {e}"
                raise RuntimeError(error_message) from e
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
......@@ -497,6 +508,15 @@ class TokenizerManager:
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if self.server_args.tokenizer_worker_num > 1:
# Modify rid, add worker_id
if isinstance(obj.rid, list):
# If it's an array, add worker_id prefix to each element
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
else:
# If it's a single value, add worker_id prefix
obj.rid = f"{self.worker_id}_{obj.rid}"
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
......@@ -1096,6 +1116,8 @@ class TokenizerManager:
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
) -> Tuple[bool, str]:
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWarpper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
if self.server_args.dp_size == 1:
......@@ -1315,6 +1337,8 @@ class TokenizerManager:
elif obj.session_id in self.session_futures:
return None
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWarpper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[obj.session_id] = asyncio.Future()
......@@ -1590,7 +1614,6 @@ class TokenizerManager:
async def handle_loop(self):
    """The event loop that handles requests.

    Blocks on the detokenizer socket forever, handing every received
    object to the result dispatcher.
    """
    while True:
        incoming = await self.recv_from_detokenizer.recv_pyobj()
        self._result_dispatcher(incoming)
......@@ -1610,9 +1633,12 @@ class TokenizerManager:
)
continue
origin_rid = rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(rid)
# Build meta_info and return value
meta_info = {
"id": rid,
"id": origin_rid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
"weight_version": self.server_args.weight_version,
......@@ -1918,6 +1944,9 @@ class TokenizerManager:
if is_health_check_generate_req(recv_obj):
return
state = self.rid_to_state[recv_obj.rid]
origin_rid = recv_obj.rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(origin_rid)
state.finished = True
if recv_obj.finished_reason:
out = {
......@@ -1930,7 +1959,7 @@ class TokenizerManager:
out = {
"text": "",
"meta_info": {
"id": recv_obj.rid,
"id": origin_rid,
"finish_reason": {
"type": "abort",
"message": "Abort before prefill",
......@@ -2116,6 +2145,8 @@ T = TypeVar("T")
class _Communicator(Generic[T]):
"""Note: The communicator now only run up to 1 in-flight request at any time."""
enable_multi_tokenizer = False
def __init__(self, sender, fan_out: int):
self._sender = sender
self._fan_out = fan_out
......@@ -2132,6 +2163,8 @@ class _Communicator(Generic[T]):
assert self._result_values is None
if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj)
self._result_event = asyncio.Event()
......
......@@ -128,6 +128,7 @@ class ServerArgs:
model_path: str
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
tokenizer_worker_num: int = 1
skip_tokenizer_init: bool = False
load_format: str = "auto"
model_loader_extra_config: str = "{}"
......@@ -827,6 +828,12 @@ class ServerArgs:
default=ServerArgs.tokenizer_path,
help="The path of the tokenizer.",
)
parser.add_argument(
"--tokenizer-worker-num",
type=int,
default=ServerArgs.tokenizer_worker_num,
help="The worker num of the tokenizer manager.",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
......@@ -2176,6 +2183,9 @@ class ServerArgs:
self.chunked_prefill_size % self.page_size == 0
), "chunked_prefill_size must be divisible by page_size"
# Check multi tokenizer
assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1"
def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
......@@ -2419,6 +2429,9 @@ class PortArgs:
# The ipc filename for Scheduler to send metrics
metrics_ipc_name: str
# The ipc filename for Tokenizer and worker tokenizer
tokenizer_worker_ipc_name: Optional[str]
@staticmethod
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
if server_args.nccl_port is None:
......@@ -2442,6 +2455,7 @@ class PortArgs:
nccl_port=nccl_port,
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
tokenizer_worker_ipc_name=None,
)
else:
# DP attention. Use TCP + port to handle both single-node and multi-node.
......@@ -2475,6 +2489,7 @@ class PortArgs:
nccl_port=nccl_port,
rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
tokenizer_worker_ipc_name=None,
)
......
......@@ -2787,6 +2787,20 @@ def lru_cache_frozenset(maxsize=128):
return decorator
def get_worker_ids_from_req_rids(rids):
    """Extract tokenizer-worker ids from request id(s).

    In multi-tokenizer mode a rid is prefixed with "<worker_id>_"; the id is
    the integer before the first underscore. Accepts a single rid string or
    a list of them; any other input yields an empty list.
    """
    if isinstance(rids, list):
        return [int(rid.split("_")[0]) for rid in rids]
    if isinstance(rids, str):
        return [int(rids.split("_")[0])]
    return []
def get_origin_rid(rid):
    """Strip the "<worker_id>_" prefix from a rid.

    Returns everything after the first underscore; a rid without an
    underscore is returned unchanged.
    """
    _prefix, sep, remainder = rid.partition("_")
    return remainder if sep else rid
def apply_module_patch(target_module, target_function, wrappers):
original_module, original_function = parse_module_path(
target_module, target_function, False
......
......@@ -85,6 +85,7 @@ suites = {
TestFile("test_mla_int8_deepseek_v3.py", 429),
TestFile("test_mla_flashinfer.py", 302),
TestFile("test_mla_fp8.py", 93),
TestFile("test_multi_tokenizer.py", 230),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 234),
TestFile("test_original_logprobs.py", 200),
......
import unittest
from types import SimpleNamespace
import sglang.srt.managers.io_struct as io_struct
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
auto_config_device,
get_benchmark_args,
is_in_ci,
popen_launch_server,
run_benchmark,
write_github_step_summary,
)
class TestMultiTokenizer(CustomTestCase):
    """End-to-end checks for the multi-process tokenizer manager.

    Server-launch pattern mirrors test_hicache.py.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--tokenizer-worker-num",
            8,
            "--mem-fraction-static",
            0.7,
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """MMLU accuracy must not regress with 8 tokenizer workers."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(eval_args)
        self.assertGreaterEqual(metrics["score"], 0.65)

    def test_multi_tokenizer_ttft(self):
        """Latency benchmark; pattern from test_bench_serving.run_bench_serving."""
        bench_args = get_benchmark_args(
            base_url=self.base_url,
            dataset_name="random",
            dataset_path="",
            tokenizer=None,
            num_prompts=100,
            random_input_len=4096,
            random_output_len=2048,
            sharegpt_context_len=None,
            request_rate=1,
            disable_stream=False,
            disable_ignore_eos=False,
            seed=0,
            device=auto_config_device(),
            lora_name=None,
        )
        res = run_benchmark(bench_args)
        if is_in_ci():
            write_github_step_summary(
                f"### test_multi_tokenizer_ttft\n"
                f"median_e2e_latency_ms: {res['median_e2e_latency_ms']:.2f} ms\n"
            )
        self.assertLess(res["median_e2e_latency_ms"], 11000)
        self.assertLess(res["median_ttft_ms"], 86)
        self.assertLess(res["median_itl_ms"], 10)
# Allow running this test file directly with `python`.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment