Unverified Commit 7490e3f6 authored by ybyang, committed by GitHub

Support Multi Process Tokenizer Manager (#6555)


Signed-off-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: lw9527 <952799980@qq.com>
Co-authored-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
parent 6ee6619b
python/sglang/srt/entrypoints/http_server.py

@@ -18,14 +18,18 @@ This file implements HTTP APIs for the inference engine via fastapi.
"""

import asyncio
import ctypes
import dataclasses
import json
import logging
import multiprocessing as multiprocessing
import os
import sys
import tempfile
import threading
import time
from http import HTTPStatus
from multiprocessing import Lock, Manager, Value, shared_memory
from typing import AsyncIterator, Callable, Dict, Optional

# Fix a bug of Python threading
@@ -94,7 +98,7 @@ from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    add_prometheus_middleware,
@@ -129,8 +133,165 @@ def set_global_state(global_state: _GlobalState):
    _global_state = global_state

def serialize_port_args(port_args: PortArgs) -> dict:
    """Serialize PortArgs into a shareable dictionary."""
    return {
        "tokenizer_ipc_name": port_args.tokenizer_ipc_name,
        "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
        "detokenizer_ipc_name": port_args.detokenizer_ipc_name,
        "nccl_port": port_args.nccl_port,
        "rpc_ipc_name": port_args.rpc_ipc_name,
        "metrics_ipc_name": port_args.metrics_ipc_name,
        "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
    }


def deserialize_port_args(data: dict) -> PortArgs:
    """Deserialize PortArgs from a shared dictionary."""
    return PortArgs(**data)


def serialize_server_args(server_args: ServerArgs) -> dict:
    """Serialize ServerArgs into a shareable dictionary."""
    return dataclasses.asdict(server_args)


def deserialize_server_args(data: dict) -> ServerArgs:
    """Deserialize ServerArgs from a shared dictionary."""
    return ServerArgs(**data)


def serialize_scheduler_info(scheduler_info: Dict) -> dict:
    """Serialize scheduler_info into a shareable dictionary."""
    return scheduler_info


def deserialize_scheduler_info(data: dict) -> Dict:
    """Deserialize scheduler_info from a shared dictionary."""
    return data

def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
    """Write data to shared memory."""
    serialized = json.dumps(data).encode("utf-8")
    size = len(serialized)
    try:
        # Try to open existing shared memory
        shm = shared_memory.SharedMemory(name=name)
        # If the size is insufficient, close and recreate
        if shm.size < size:
            shm.close()
            shm.unlink()
            shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    except FileNotFoundError:
        # If not present, create new shared memory
        shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    shm.buf[:size] = serialized
    return shm


def read_from_shared_memory(name: str) -> dict:
    """Read data from shared memory."""
    try:
        shm = shared_memory.SharedMemory(name=name)
        data = json.loads(bytes(shm.buf).decode("utf-8"))
        shm.close()
        return data
    except FileNotFoundError:
        raise FileNotFoundError(f"Shared memory {name} not found")


def get_main_process_id() -> int:
    """Get the main process ID (the parent of the current process)."""
    return multiprocessing.current_process()._parent_pid

def write_data_for_multi_tokenizer(
    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
    """Write the args to shared memory for the multi-tokenizer workers."""
    # The segments are keyed by the main process ID, which workers can derive
    main_pid = get_main_process_id()
    current_pid = os.getpid()
    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

    port_args_shm = write_to_shared_memory(
        serialize_port_args(port_args), f"port_args_{current_pid}"
    )
    server_args_shm = write_to_shared_memory(
        serialize_server_args(server_args), f"server_args_{current_pid}"
    )
    scheduler_info_shm = write_to_shared_memory(
        serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
    )

    # Close the local mappings; the segments stay alive until unlink()
    port_args_shm.close()
    server_args_shm.close()
    scheduler_info_shm.close()

    return port_args_shm, server_args_shm, scheduler_info_shm

def init_multi_tokenizer() -> ServerArgs:
    """Read the args from shared memory and init a tokenizer manager for the current process."""
    pid = os.getpid()
    main_pid = get_main_process_id()
    logger.info(f"current worker_id: {pid}, main process ID: {main_pid}")

    # Read port_args, server_args, and scheduler_info from shared memory
    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
    port_args = deserialize_port_args(port_args_data)
    server_args = deserialize_server_args(server_args_data)
    scheduler_info = deserialize_scheduler_info(scheduler_info_data)

    # Each worker gets its own IPC endpoint for receiving detokenized results
    port_args.tokenizer_ipc_name = (
        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
    )

    # Create the tokenizer manager for this worker process
    tokenizer_manager = TokenizerManager(server_args, port_args, False)
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )
    # Register this worker with the main tokenizer manager
    tokenizer_manager.register_to_main_tokenizer_manager()

    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
    set_global_state(
        _GlobalState(
            tokenizer_manager=tokenizer_manager,
            template_manager=template_manager,
            scheduler_info=scheduler_info,
        )
    )
    return server_args

@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
    server_args = getattr(fast_api_app, "server_args", None)
    if server_args is None:
        # Multi-tokenizer worker: initialize from shared memory.
        # Keep the local reference too; the finally block below uses it.
        server_args = fast_api_app.server_args = init_multi_tokenizer()
        fast_api_app.warmup_thread = threading.Thread(
            target=_wait_and_warmup,
            args=(
                fast_api_app.server_args,
                None,  # pipe_finish_writer not needed in worker
                None,  # launch_callback not needed in worker
            ),
        )

    # Initialize OpenAI serving handlers
    fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
        _global_state.tokenizer_manager, _global_state.template_manager
@@ -191,7 +352,15 @@ async def lifespan(fast_api_app: FastAPI):
    warmup_thread = getattr(fast_api_app, "warmup_thread", None)
    if warmup_thread is not None:
        warmup_thread.start()

    try:
        yield
    finally:
        if server_args.tokenizer_worker_num > 1:
            pid = os.getpid()
            logger.info(f"uvicorn worker {pid} ending...")
            warmup_thread.join()
            logger.info(f"uvicorn worker {pid} ended")

# Fast API
@@ -208,6 +377,30 @@ app.add_middleware(
)


def setup_middlewares():
    """Set up all middlewares for both single- and multi-process modes."""
    worker_pid = os.getpid()

    # API key middleware
    api_key = os.environ.get("SGLANG_API_KEY", "")
    if api_key:
        add_api_key_middleware(app, api_key)
        logger.info(f"Worker {worker_pid} added API key middleware")

    # Prometheus middleware, enabled via an environment variable so that
    # every uvicorn worker picks it up
    enable_metrics = get_bool_env_var("SGLANG_ENABLE_METRICS", "false")
    if enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()
        logger.info(f"Worker {worker_pid} added prometheus middleware")


# Called at module import time so the middlewares exist in every worker process
setup_middlewares()

@app.exception_handler(HTTPException)
async def validation_exception_handler(request: Request, exc: HTTPException):
    """Enrich HTTP exception with status code and other details"""
@@ -993,9 +1186,19 @@ def launch_server(
    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
    """
    if server_args.tokenizer_worker_num > 1:
        port_args = PortArgs.init_new(server_args)
        port_args.tokenizer_worker_ipc_name = (
            f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
        )
        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
            server_args=server_args, port_args=port_args
        )
    else:
        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
            server_args=server_args,
        )

    set_global_state(
        _GlobalState(
            tokenizer_manager=tokenizer_manager,
@@ -1004,6 +1207,24 @@ def launch_server(
        )
    )

    if server_args.tokenizer_worker_num > 1:
        # Set environment variables for middlewares in main process
        if server_args.api_key:
            os.environ["SGLANG_API_KEY"] = server_args.api_key
            logger.info("Main process set SGLANG_API_KEY")
        if server_args.enable_metrics:
            os.environ["SGLANG_ENABLE_METRICS"] = "true"
            logger.info("Main process set SGLANG_ENABLE_METRICS=true")

        port_args_shm, server_args_shm, scheduler_info_shm = (
            write_data_for_multi_tokenizer(
                port_args,
                server_args,
                scheduler_info,
            )
        )
    else:
        # Add api key authorization
        if server_args.api_key:
            add_api_key_middleware(app, server_args.api_key)
@@ -1030,6 +1251,24 @@ def launch_server(
        set_uvicorn_logging_configs()
        app.server_args = server_args

        # Listen for HTTP requests
        if server_args.tokenizer_worker_num > 1:
            from uvicorn.config import LOGGING_CONFIG

            LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
                "handlers": ["default"],
                "level": "INFO",
                "propagate": False,
            }
            uvicorn.run(
                "sglang.srt.entrypoints.http_server:app",
                host=server_args.host,
                port=server_args.port,
                log_level=server_args.log_level_http or server_args.log_level,
                timeout_keep_alive=5,
                loop="uvloop",
                workers=server_args.tokenizer_worker_num,
            )
        else:
            uvicorn.run(
                app,
                host=server_args.host,
@@ -1039,6 +1278,11 @@ def launch_server(
                loop="uvloop",
            )
    finally:
        if server_args.tokenizer_worker_num > 1:
            port_args_shm.unlink()
            server_args_shm.unlink()
            scheduler_info_shm.unlink()
        else:
            warmup_thread.join()
...
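
A minimal, self-contained sketch of the handoff above: the launcher publishes its
serialized args in named shared-memory segments keyed by its own PID, and each
uvicorn worker, whose parent is the launcher, derives the same names and reads them
back. The payload and names below are illustrative stand-ins, not the diff's exact
structures:

    import json
    import os
    from multiprocessing import shared_memory

    def publish(payload: dict, name: str) -> shared_memory.SharedMemory:
        data = json.dumps(payload).encode("utf-8")
        shm = shared_memory.SharedMemory(create=True, size=len(data), name=name)
        shm.buf[: len(data)] = data
        return shm  # keep the handle so the segment can be unlink()ed at shutdown

    def consume(name: str) -> dict:
        shm = shared_memory.SharedMemory(name=name)
        try:
            return json.loads(bytes(shm.buf).decode("utf-8"))
        finally:
            shm.close()

    # Launcher:  shm = publish({"port": 30000}, f"server_args_{os.getpid()}")
    # Worker:    args = consume(f"server_args_{os.getppid()}")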

python/sglang/srt/managers/detokenizer_manager.py

@@ -31,10 +31,12 @@ from sglang.srt.managers.io_struct import (
    BatchMultimodalOut,
    BatchStrOut,
    BatchTokenIDOut,
    MultiTokenizerRegisterReq,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    configure_logger,
    get_workerids_from_rids,
    get_zmq_socket,
    kill_itself_when_parent_died,
)
@@ -81,7 +83,6 @@ class DetokenizerManager:
        self.send_to_tokenizer = get_zmq_socket(
            context, zmq.PUSH, port_args.tokenizer_ipc_name, False
        )
        if server_args.skip_tokenizer_init:
            self.tokenizer = None
        else:
@@ -94,21 +95,208 @@ class DetokenizerManager:
        self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
        self.is_dummy = server_args.load_format == "dummy"
        self.tokenizer_worker_num = server_args.tokenizer_worker_num

        self._request_dispatcher = TypeBasedDispatcher(
            [
                (BatchEmbeddingOut, self.handle_batch_embedding_out),
                (BatchTokenIDOut, self.handle_batch_token_id_out),
                (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
                (MultiTokenizerRegisterReq, lambda x: None),
            ]
        )

    def event_loop(self):
        """The event loop that handles requests"""
        while True:
            try:
                recv_obj = self.recv_from_scheduler.recv_pyobj()
                output = self._request_dispatcher(recv_obj)
                if self.tokenizer_worker_num <= 1:
                    self.send_to_tokenizer.send_pyobj(output)
                else:
                    # Extract the worker id from each rid
                    if isinstance(recv_obj.rids, list):
                        worker_ids = get_workerids_from_rids(recv_obj.rids)
                    else:
                        raise RuntimeError(
                            "when tokenizer_worker_num > 1, recv_obj.rids must be a list"
                        )
                    if not hasattr(self, "tokenizer_mapping"):
                        self.tokenizer_mapping = {}
                    # Create a ZMQ context if needed
                    if not hasattr(self, "_zmq_context"):
                        self._zmq_context = zmq.Context()

                    # Send each output through the socket of its worker
                    for i, worker_id in enumerate(worker_ids):
                        if worker_id not in self.tokenizer_mapping:
                            # Register the worker if not already done
                            if isinstance(recv_obj, MultiTokenizerRegisterReq):
                                self.init_tokenizer_mapping(recv_obj, worker_id)
                            else:
                                logger.error(
                                    f"Worker {worker_id} is not registered and was not "
                                    "found in the tokenizer mapping. Please ensure the "
                                    "worker is registered correctly."
                                )
                            continue
                        else:
                            if isinstance(recv_obj, MultiTokenizerRegisterReq):
                                continue
                        # Create a new single-request output object of the same type
                        if isinstance(output, BatchEmbeddingOut):
                            new_output = BatchEmbeddingOut(
                                rids=[output.rids[i]],
                                finished_reasons=[output.finished_reasons[i]],
                                embeddings=[output.embeddings[i]],
                                prompt_tokens=[output.prompt_tokens[i]],
                                cached_tokens=[output.cached_tokens[i]],
                            )
                        elif isinstance(output, BatchStrOut):
                            new_output = BatchStrOut(
                                rids=[output.rids[i]],
                                finished_reasons=(
                                    [output.finished_reasons[i]]
                                    if len(output.finished_reasons) > i
                                    else None
                                ),
                                output_strs=(
                                    [output.output_strs[i]]
                                    if len(output.output_strs) > i
                                    else None
                                ),
                                output_ids=(
                                    [output.output_ids[i]]
                                    if output.output_ids and len(output.output_ids) > i
                                    else None
                                ),
                                prompt_tokens=(
                                    [output.prompt_tokens[i]]
                                    if len(output.prompt_tokens) > i
                                    else None
                                ),
                                completion_tokens=(
                                    [output.completion_tokens[i]]
                                    if len(output.completion_tokens) > i
                                    else None
                                ),
                                cached_tokens=(
                                    [output.cached_tokens[i]]
                                    if len(output.cached_tokens) > i
                                    else None
                                ),
                                spec_verify_ct=(
                                    [output.spec_verify_ct[i]]
                                    if len(output.spec_verify_ct) > i
                                    else None
                                ),
                                input_token_logprobs_val=(
                                    [output.input_token_logprobs_val[i]]
                                    if output.input_token_logprobs_val
                                    else None
                                ),
                                input_token_logprobs_idx=(
                                    [output.input_token_logprobs_idx[i]]
                                    if output.input_token_logprobs_idx
                                    else None
                                ),
                                output_token_logprobs_val=(
                                    [output.output_token_logprobs_val[i]]
                                    if output.output_token_logprobs_val
                                    else None
                                ),
                                output_token_logprobs_idx=(
                                    [output.output_token_logprobs_idx[i]]
                                    if output.output_token_logprobs_idx
                                    else None
                                ),
                                input_top_logprobs_val=(
                                    [output.input_top_logprobs_val[i]]
                                    if output.input_top_logprobs_val
                                    else None
                                ),
                                input_top_logprobs_idx=(
                                    [output.input_top_logprobs_idx[i]]
                                    if output.input_top_logprobs_idx
                                    else None
                                ),
                                output_top_logprobs_val=(
                                    [output.output_top_logprobs_val[i]]
                                    if output.output_top_logprobs_val
                                    else None
                                ),
                                output_top_logprobs_idx=(
                                    [output.output_top_logprobs_idx[i]]
                                    if output.output_top_logprobs_idx
                                    else None
                                ),
                                input_token_ids_logprobs_val=(
                                    [output.input_token_ids_logprobs_val[i]]
                                    if output.input_token_ids_logprobs_val
                                    else None
                                ),
                                input_token_ids_logprobs_idx=(
                                    [output.input_token_ids_logprobs_idx[i]]
                                    if output.input_token_ids_logprobs_idx
                                    else None
                                ),
                                output_token_ids_logprobs_val=(
                                    [output.output_token_ids_logprobs_val[i]]
                                    if output.output_token_ids_logprobs_val
                                    else None
                                ),
                                output_token_ids_logprobs_idx=(
                                    [output.output_token_ids_logprobs_idx[i]]
                                    if output.output_token_ids_logprobs_idx
                                    else None
                                ),
                                output_hidden_states=(
                                    [output.output_hidden_states[i]]
                                    if output.output_hidden_states
                                    else None
                                ),
                            )
                        elif isinstance(output, BatchMultimodalOut):
                            new_output = BatchMultimodalOut(
                                rids=[output.rids[i]],
                                finished_reasons=[output.finished_reasons[i]],
                                prompt_tokens=[output.prompt_tokens[i]],
                                completion_tokens=[output.completion_tokens[i]],
                                cached_tokens=[output.cached_tokens[i]],
                            )
                        else:
                            new_output = output
                        try:
                            self.tokenizer_mapping[worker_id].send_pyobj(new_output)
                        except zmq.error.ZMQError as e:
                            logger.info(
                                f"ZMQ error when sending to worker {worker_id}: {e}"
                            )
            except Exception as e:
                logger.error(f"Error in detokenizer event loop: {e}")
                raise e

    def init_tokenizer_mapping(
        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
    ):
        """Initialize the tokenizer mapping from a register request."""
        ipc_name = recv_obj.ipc_name
        worker_id_int = int(worker_id)

        if worker_id_int not in self.tokenizer_mapping:
            socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
            self.tokenizer_mapping[worker_id_int] = socket
            logger.info(
                f"DetokenizerManager created a ZMQ socket for worker {worker_id} "
                f"with ipc_name {ipc_name}"
            )
        else:
            logger.info(
                f"ZMQ socket for worker {worker_id} already exists, skipping creation"
            )

    def trim_matched_stop(
        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
...
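
The fan-out in event_loop reduces to three steps: recover the worker id from each
rid, look up that worker's PUSH socket (registering it lazily on the first
MultiTokenizerRegisterReq), and re-batch row i so each worker only sees its own
requests. A condensed sketch of the routing step, assuming the
"{worker_id}_{original_rid}" rid convention and a toy output type in place of
BatchStrOut:

    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass
    class ToyBatchOut:  # stand-in for BatchStrOut and friends
        rids: List[str]
        output_strs: List[str]

    def fan_out(output: ToyBatchOut, sockets: Dict[int, object]) -> None:
        """Route each row of a batched output to the worker that owns it."""
        for i, rid in enumerate(output.rids):
            worker_id = int(rid.split("_")[0])
            sock = sockets.get(worker_id)  # a zmq PUSH socket in the real code
            if sock is None:
                continue  # unregistered worker; the real loop logs an error instead
            single = ToyBatchOut(rids=[rid], output_strs=[output.output_strs[i]])
            sock.send_pyobj(single)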

python/sglang/srt/managers/io_struct.py

@@ -782,12 +782,13 @@ class BatchEmbeddingOut:

@dataclass
class FlushCacheReqInput:
    rids: Optional[Union[List[str], str]] = None


@dataclass
class FlushCacheReqOutput:
    success: bool
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -798,6 +799,7 @@ class UpdateWeightFromDiskReqInput:
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -806,6 +808,7 @@ class UpdateWeightFromDiskReqOutput:
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -819,12 +822,14 @@ class UpdateWeightsFromDistributedReqInput:
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    rids: Optional[Union[List[str], str]] = None


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    success: bool
    message: str
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -842,12 +847,14 @@ class UpdateWeightsFromTensorReqInput:
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    rids: Optional[Union[List[str], str]] = None


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -864,23 +871,27 @@ class InitWeightsUpdateGroupReqInput:
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"
    rids: Optional[Union[List[str], str]] = None


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str
    rids: Optional[Union[List[str], str]] = None


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100
    rids: Optional[Union[List[str], str]] = None


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -888,11 +899,12 @@ class ReleaseMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None
    rids: Optional[Union[List[str], str]] = None


@dataclass
class ReleaseMemoryOccupationReqOutput:
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -900,21 +912,23 @@ class ResumeMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None
    rids: Optional[Union[List[str], str]] = None


@dataclass
class ResumeMemoryOccupationReqOutput:
    rids: Optional[Union[List[str], str]] = None


@dataclass
class SlowDownReqInput:
    forward_sleep_time: Optional[float]
    rids: Optional[Union[List[str], str]] = None


@dataclass
class SlowDownReqOutput:
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -923,29 +937,37 @@ class AbortReq:
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False
    rids: Optional[Union[List[str], str]] = None
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Mirror the single rid into rids so routing code can treat
        # abort requests like batch outputs
        self.rids = self.rid


@dataclass
class GetInternalStateReq:
    rids: Optional[Union[List[str], str]] = None


@dataclass
class GetInternalStateReqOutput:
    internal_state: Dict[Any, Any]
    rids: Optional[Union[List[str], str]] = None


@dataclass
class SetInternalStateReq:
    server_args: Dict[str, Any]
    rids: Optional[Union[List[str], str]] = None


@dataclass
class SetInternalStateReqOutput:
    updated: bool
    server_args: Dict[str, Any]
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -961,6 +983,7 @@ class ProfileReqInput:
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    rids: Optional[Union[List[str], str]] = None


class ProfileReqType(Enum):
@@ -979,12 +1002,14 @@ class ProfileReq:
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None
    rids: Optional[Union[List[str], str]] = None


@dataclass
class ProfileReqOutput:
    success: bool
    message: str
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -993,27 +1018,32 @@ class ConfigureLoggingReq:
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None
    rids: Optional[Union[List[str], str]] = None


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None
    rids: Optional[Union[List[str], str]] = None


@dataclass
class CloseSessionReqInput:
    session_id: str
    rids: Optional[Union[List[str], str]] = None


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool
    rids: Optional[Union[List[str], str]] = None


@dataclass
class HealthCheckOutput:
    rids: Optional[Union[List[str], str]] = None
    pass

@@ -1025,7 +1055,7 @@ class ExpertDistributionReq(Enum):

@dataclass
class ExpertDistributionReqOutput:
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -1050,18 +1080,21 @@ class ParseFunctionCallReq:
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
    rids: Optional[Union[List[str], str]] = None


@dataclass
class SeparateReasoningReqInput:
    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".
    rids: Optional[Union[List[str], str]] = None


@dataclass
class VertexGenerateReqInput:
    instances: List[dict]
    parameters: Optional[dict] = None
    rids: Optional[Union[List[str], str]] = None


@dataclass
@@ -1086,6 +1119,7 @@ class LoadLoRAAdapterReqInput:
    pinned: bool = False
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None
    rids: Optional[Union[List[str], str]] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
@@ -1102,6 +1136,7 @@ class UnloadLoRAAdapterReqInput:
    lora_name: str
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None
    rids: Optional[Union[List[str], str]] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
@@ -1115,11 +1150,18 @@ class LoRAUpdateResult:
    success: bool
    error_message: Optional[str] = None
    loaded_adapters: Optional[Dict[str, LoRARef]] = None
    rids: Optional[Union[List[str], str]] = None


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


@dataclass
class MultiTokenizerRegisterReq:
    rids: Optional[Union[List[str], str]] = None
    ipc_name: Optional[str] = None


class BlockReqType(Enum):
    BLOCK = 1
    UNBLOCK = 2
...
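
The common thread in this file: every control-plane request and response dataclass
gains an optional rids field so that replies can carry the originating worker's rid
prefix back through the scheduler. AbortReq shows the mechanics; the rid value here
is made up:

    # assuming: from sglang.srt.managers.io_struct import AbortReq
    req = AbortReq(rid="12345_chatcmpl-42")
    assert req.rids == "12345_chatcmpl-42"  # mirrored by __post_init__, so abort
                                            # replies route like any batch output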

python/sglang/srt/managers/scheduler.py

@@ -79,6 +79,7 @@ from sglang.srt.managers.io_struct import (
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    MultiTokenizerRegisterReq,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
@@ -247,7 +248,6 @@ class Scheduler(
        # Init inter-process communication
        context = zmq.Context(2)
        self.idle_sleeper = None
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -522,6 +522,7 @@ class Scheduler(
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
            ]
        )
@@ -1063,6 +1064,8 @@ class Scheduler(
            if self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(output)
            else:
                # Echo the originating rids so the reply can be routed back
                # to the right tokenizer worker
                if recv_req.rids is not None:
                    output.rids = recv_req.rids
                self.send_to_tokenizer.send_pyobj(output)

    def handle_generate_request(
@@ -2400,6 +2403,10 @@ class Scheduler(
        result = self.tp_worker.unload_lora_adapter(recv_req)
        return result

    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
        # Relay the register request to the detokenizer, which opens a
        # socket back to the new tokenizer worker
        self.send_to_detokenizer.send_pyobj(recv_req)
        return recv_req

    def slow_down(self, recv_req: SlowDownReqInput):
        t = recv_req.forward_sleep_time
        if t is not None and t <= 0:
...
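
Taken together, the scheduler hunks complete a three-hop registration handshake: a
tokenizer worker announces its IPC endpoint in a MultiTokenizerRegisterReq, the
scheduler relays it verbatim, and the detokenizer connects back. A sketch of the
sequence; the rids payload in step 1 is a guess, since the diff does not show how
register_to_main_tokenizer_manager builds the request:

    # 1. Tokenizer worker: announce the endpoint it listens on.
    req = MultiTokenizerRegisterReq(
        rids=[f"{os.getpid()}_register"],  # hypothetical rid with the pid prefix
        ipc_name=port_args.tokenizer_ipc_name,
    )
    send_to_scheduler.send_pyobj(req)

    # 2. Scheduler (register_multi_tokenizer): pure pass-through.
    #    self.send_to_detokenizer.send_pyobj(recv_req)

    # 3. Detokenizer (init_tokenizer_mapping): open a PUSH socket back to the worker.
    #    self.tokenizer_mapping[worker_id] = get_zmq_socket(
    #        self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False
    #    )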

python/sglang/srt/server_args.py

@@ -51,6 +51,7 @@ class ServerArgs:
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    tokenizer_worker_num: int = 1
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    model_loader_extra_config: str = "{}"
@@ -730,6 +731,12 @@ class ServerArgs:
        default=ServerArgs.tokenizer_path,
        help="The path of the tokenizer.",
    )
    parser.add_argument(
        "--tokenizer-worker-num",
        type=int,
        default=ServerArgs.tokenizer_worker_num,
        help="The number of tokenizer manager workers.",
    )
    parser.add_argument(
        "--tokenizer-mode",
        type=str,
@@ -2081,6 +2088,9 @@ class ServerArgs:
            self.chunked_prefill_size % self.page_size == 0
        ), "chunked_prefill_size must be divisible by page_size"

        # Check multi-tokenizer args
        assert self.tokenizer_worker_num > 0, "tokenizer_worker_num must be >= 1"

    def check_lora_server_args(self):
        assert (
            self.max_loras_per_batch > 0
@@ -2246,6 +2256,9 @@ class PortArgs:
    # The ipc filename for Scheduler to send metrics
    metrics_ipc_name: str

    # The ipc filename between the main tokenizer and the tokenizer workers
    tokenizer_worker_ipc_name: Optional[str]

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        if server_args.nccl_port is None:
@@ -2269,6 +2282,7 @@ class PortArgs:
                nccl_port=nccl_port,
                rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                tokenizer_worker_ipc_name=None,
            )
        else:
            # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -2302,6 +2316,7 @@ class PortArgs:
                nccl_port=nccl_port,
                rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
                metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
                tokenizer_worker_ipc_name=None,
            )
...
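
The new knob keeps multi-tokenizer mode opt-in: tokenizer_worker_num defaults to 1,
and any larger value switches launch_server onto the multi-worker uvicorn path. For
example, passing --tokenizer-worker-num 4 on the command line is equivalent to
(the model path is a placeholder):

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
        tokenizer_worker_num=4,  # 4 uvicorn workers, each with its own TokenizerManager
    )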

python/sglang/srt/utils.py

@@ -2754,6 +2754,20 @@ def lru_cache_frozenset(maxsize=128):
    return decorator


def get_workerids_from_rids(rids):
    """Extract the worker id prefix from one rid or a list of rids."""
    if isinstance(rids, list):
        worker_ids = [int(rid.split("_")[0]) for rid in rids]
    elif isinstance(rids, str):
        worker_ids = [int(rids.split("_")[0])]
    else:
        worker_ids = []
    return worker_ids


def get_origin_rid(rid):
    """Strip the worker id prefix and return the original rid."""
    return rid.split("_", 1)[1] if "_" in rid else rid


def apply_module_patch(target_module, target_function, wrappers):
    original_module, original_function = parse_module_path(
        target_module, target_function, False
...
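
These two helpers pin down the rid convention that the whole feature relies on: a
tokenizer worker tags each request id with its PID as "{worker_id}_{original_rid}",
and the helpers recover either half. A quick round-trip with illustrative values:

    rid = "12345_chatcmpl-42"                       # worker 12345 tagged this request
    assert get_workerids_from_rids([rid]) == [12345]
    assert get_workerids_from_rids(rid) == [12345]  # a single string is also accepted
    assert get_origin_rid(rid) == "chatcmpl-42"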

test/srt/run_suite.py

@@ -78,6 +78,7 @@ suites = {
        TestFile("test_mla_int8_deepseek_v3.py", 429),
        TestFile("test_mla_flashinfer.py", 302),
        TestFile("test_mla_fp8.py", 93),
        TestFile("test_multi_tokenizer.py", 200),
        TestFile("test_no_chunked_prefill.py", 108),
        TestFile("test_no_overlap_scheduler.py", 234),
        TestFile("test_penalty.py", 41),
...

test/srt/test_multi_tokenizer.py (new file)

import inspect
import unittest
from dataclasses import fields, is_dataclass
from types import SimpleNamespace

import sglang.srt.managers.io_struct as io_struct
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    auto_config_device,
    get_benchmark_args,
    is_in_ci,
    popen_launch_server,
    run_benchmark,
    write_github_step_summary,
)


class TestMultiTokenizer(CustomTestCase):
    # adapted from test_hicache.py
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--tokenizer-worker-num",
                8,
                "--mem-fraction-static",
                0.7,
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)

    def test_all_io_struct(self):
        print("check all req types in io_struct.py")
        result = []
        for name, obj in inspect.getmembers(io_struct):
            if inspect.isclass(obj) and is_dataclass(obj):
                field_names = [f.name for f in fields(obj)]
                if "rids" in field_names or "rid" in field_names:
                    continue
                result.append(name)
        print(f"WARNING: Some request types in io_struct.py have no rids: {result}")
        print(
            "If a special request type doesn't work, check its rids field, "
            "which is needed for multi-tokenizer routing."
        )

    def test_multi_tokenizer_ttft(self):
        # adapted from run_bench_serving in test_bench_serving.py
        args = get_benchmark_args(
            base_url=self.base_url,
            dataset_name="random",
            dataset_path="",
            tokenizer=None,
            num_prompts=100,
            random_input_len=4096,
            random_output_len=2048,
            sharegpt_context_len=None,
            request_rate=1,
            disable_stream=False,
            disable_ignore_eos=False,
            seed=0,
            device=auto_config_device(),
            lora_name=None,
        )
        res = run_benchmark(args)
        if is_in_ci():
            write_github_step_summary(
                f"### test_multi_tokenizer_ttft\n"
                f"median_e2e_latency_ms: {res['median_e2e_latency_ms']:.2f} ms\n"
            )
        self.assertLess(res["median_e2e_latency_ms"], 11000)
        self.assertLess(res["median_ttft_ms"], 86)
        self.assertLess(res["median_itl_ms"], 10)


if __name__ == "__main__":
    unittest.main()