Unverified Commit a9471542 authored by Lianmin Zheng, committed by GitHub

Revert "Support Multi Process Tokenizer Manager" (#8960)

parent 41357e51
......@@ -18,18 +18,14 @@ This file implements HTTP APIs for the inference engine via fastapi.
"""
import asyncio
import ctypes
import dataclasses
import json
import logging
import multiprocessing
import os
import sys
import tempfile
import threading
import time
from http import HTTPStatus
from multiprocessing import Lock, Manager, Value, shared_memory
from typing import AsyncIterator, Callable, Dict, Optional
# Fix a bug in Python threading
......@@ -98,7 +94,7 @@ from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
add_api_key_middleware,
add_prometheus_middleware,
......@@ -133,165 +129,8 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state
def serialize_port_args(port_args: PortArgs) -> dict:
"""Serialize PortArgs into a shareable dictionary"""
return {
"tokenizer_ipc_name": port_args.tokenizer_ipc_name,
"scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
"detokenizer_ipc_name": port_args.detokenizer_ipc_name,
"nccl_port": port_args.nccl_port,
"rpc_ipc_name": port_args.rpc_ipc_name,
"metrics_ipc_name": port_args.metrics_ipc_name,
"tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
}
def deserialize_port_args(data: dict) -> PortArgs:
"""Deserialize PortArgs from a shared dictionary"""
return PortArgs(**data)
def serialize_server_args(server_args: ServerArgs) -> dict:
"""Serialize ServerArgs into a shareable dictionary"""
return dataclasses.asdict(server_args)
def deserialize_server_args(data: dict) -> ServerArgs:
"""Deserialize ServerArgs from a shared dictionary"""
return ServerArgs(**data)
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
"""Serialize scheduler_info into a shareable dictionary"""
return scheduler_info
def deserialize_scheduler_info(data: dict) -> Dict:
"""Deserialize scheduler_info from a shared dictionary"""
return data
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
"""Write data to shared memory"""
serialized = json.dumps(data).encode("utf-8")
size = len(serialized)
try:
# Try to open existing shared memory
shm = shared_memory.SharedMemory(name=name)
# If size is insufficient, close and recreate
if shm.size < size:
shm.close()
shm.unlink()
shm = shared_memory.SharedMemory(create=True, size=size, name=name)
except FileNotFoundError:
# If not present, create new shared memory
shm = shared_memory.SharedMemory(create=True, size=size, name=name)
shm.buf[:size] = serialized
return shm
def read_from_shared_memory(name: str) -> dict:
"""Read data from shared memory"""
try:
shm = shared_memory.SharedMemory(name=name)
data = json.loads(bytes(shm.buf).decode("utf-8"))
shm.close()
return data
except FileNotFoundError:
raise FileNotFoundError(f"Shared memory {name} not found")
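# Illustrative helper (a sketch only, not used elsewhere in this file): the
# intended round trip through the two functions above, with made-up names.
def _shared_memory_round_trip_example():
    shm = write_to_shared_memory({"nccl_port": 12345}, "example_args_4242")
    assert read_from_shared_memory("example_args_4242") == {"nccl_port": 12345}
    shm.close()
    shm.unlink()  # free the segment once every reader has finished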
def get_main_process_id() -> int:
"""Get the main process ID"""
return multiprocessing.current_process()._parent_pid
def write_data_for_multi_tokenizer(
port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
"""Write args information to share memory for multi-tokenizer"""
# get main process ID
main_pid = get_main_process_id()
current_pid = os.getpid()
logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
# Write port_args to shared memory
port_args_shm = write_to_shared_memory(
serialize_port_args(port_args), f"port_args_{current_pid}"
)
# Write server_args to shared memory
server_args_shm = write_to_shared_memory(
serialize_server_args(server_args), f"server_args_{current_pid}"
)
# Write scheduler_info to shared memory
scheduler_info_shm = write_to_shared_memory(
serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
)
port_args_shm.close()
server_args_shm.close()
scheduler_info_shm.close()
return port_args_shm, server_args_shm, scheduler_info_shm
def init_multi_tokenizer() -> ServerArgs:
"""Read args information from shm and init tokenizer manager for current process"""
pid = os.getpid()
main_pid = get_main_process_id()
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
# Read port_args, server_args, and scheduler_info from shared memory
port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
port_args = deserialize_port_args(port_args_data)
server_args = deserialize_server_args(server_args_data)
scheduler_info = deserialize_scheduler_info(scheduler_info_data)
port_args.tokenizer_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
)
# Create the tokenizer manager for this worker process
tokenizer_manager = TokenizerManager(server_args, port_args, False)
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
)
# Register this worker with the main tokenizer manager
tokenizer_manager.register_to_main_tokenizer_manager()
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
template_manager=template_manager,
scheduler_info=scheduler_info,
)
)
return server_args
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args = getattr(fast_api_app, "server_args", None)
if server_args is None:
# for multi-tokenizer
fast_api_app.server_args = init_multi_tokenizer()
fast_api_app.warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
fast_api_app.server_args,
None, # pipe_finish_writer not needed in worker
None, # launch_callback not needed in worker
),
)
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
......@@ -352,15 +191,7 @@ async def lifespan(fast_api_app: FastAPI):
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
try:
yield
finally:
if server_args.tokenizer_worker_num > 1:
pid = os.getpid()
logger.info(f"uvicorn worker {pid} ending...")
warmup_thread.join()
logger.info(f"uvicorn {pid} ended")
yield
# Fast API
......@@ -377,30 +208,6 @@ app.add_middleware(
)
# Function to set up all middlewares for multi-process compatibility
def setup_middlewares():
"""Set up all middlewares for both single- and multi-process modes"""
worker_pid = os.getpid()
# Setup API key middleware
api_key = os.environ.get("SGLANG_API_KEY", "")
if api_key:
add_api_key_middleware(app, api_key)
logger.info(f"Worker {worker_pid} added API key middleware")
# Setup prometheus middleware
# Check if metrics are enabled via environment variable
enable_metrics = get_bool_env_var("SGLANG_ENABLE_METRICS", "false")
if enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
logger.info(f"Worker {worker_pid} added prometheus middleware")
# Call setup function at module level for multi-process compatibility
setup_middlewares()
@app.exception_handler(HTTPException)
async def validation_exception_handler(request: Request, exc: HTTPException):
"""Enrich HTTP exception with status code and other details"""
......@@ -1186,19 +993,9 @@ def launch_server(
1. The HTTP server, Engine, and TokenizerManager all run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library. A standalone PUSH/PULL sketch of this pattern follows this function.
"""
if server_args.tokenizer_worker_num > 1:
port_args = PortArgs.init_new(server_args)
port_args.tokenizer_worker_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
)
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args, port_args=port_args
)
else:
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args,
)
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args
)
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
......@@ -1207,83 +1004,42 @@ def launch_server(
)
)
if server_args.tokenizer_worker_num > 1:
# Set environment variables for middlewares in main process
if server_args.api_key:
os.environ["SGLANG_API_KEY"] = server_args.api_key
logger.info("Main process set SGLANG_API_KEY")
if server_args.enable_metrics:
os.environ["SGLANG_ENABLE_METRICS"] = "true"
logger.info("Main process set SGLANG_ENABLE_METRICS=true")
port_args_shm, server_args_shm, scheduler_info_shm = (
write_data_for_multi_tokenizer(
port_args,
server_args,
scheduler_info,
)
)
else:
# Add api key authorization
if server_args.api_key:
add_api_key_middleware(app, server_args.api_key)
# Add prometheus middleware
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Send a warmup request - we will create the thread and launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
pipe_finish_writer,
launch_callback,
),
)
app.warmup_thread = warmup_thread
# Add api key authorization
if server_args.api_key:
add_api_key_middleware(app, server_args.api_key)
# Add prometheus middleware
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Send a warmup request - we will create the thread and launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
pipe_finish_writer,
launch_callback,
),
)
app.warmup_thread = warmup_thread
try:
# Update logging configs
set_uvicorn_logging_configs()
app.server_args = server_args
# Listen for HTTP requests
if server_args.tokenizer_worker_num > 1:
from uvicorn.config import LOGGING_CONFIG
LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
"handlers": ["default"],
"level": "INFO",
"propagate": False,
}
uvicorn.run(
"sglang.srt.entrypoints.http_server:app",
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level_http or server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
workers=server_args.tokenizer_worker_num,
)
else:
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level_http or server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
)
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level_http or server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
)
finally:
if server_args.tokenizer_worker_num > 1:
port_args_shm.unlink()
server_args_shm.unlink()
scheduler_info_shm.unlink()
else:
warmup_thread.join()
warmup_thread.join()
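# Standalone sketch of the ZMQ-over-IPC pattern described in launch_server's
# docstring (hypothetical socket name and payload; not part of the server code).
def _zmq_ipc_push_pull_example():
    import tempfile as _tempfile
    import zmq as _zmq

    ipc_name = f"ipc://{_tempfile.NamedTemporaryFile(delete=False).name}"
    ctx = _zmq.Context()
    pull = ctx.socket(_zmq.PULL)  # e.g. the receiving end (scheduler side)
    pull.bind(ipc_name)
    push = ctx.socket(_zmq.PUSH)  # e.g. the sending end (tokenizer side)
    push.connect(ipc_name)
    push.send_pyobj({"rid": "example"})  # pickled Python objects, as in this file
    return pull.recv_pyobj()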
def _execute_server_warmup(
......
......@@ -31,12 +31,10 @@ from sglang.srt.managers.io_struct import (
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
MultiTokenizerRegisterReq,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
get_workerids_from_rids,
get_zmq_socket,
kill_itself_when_parent_died,
)
......@@ -83,6 +81,7 @@ class DetokenizerManager:
self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
if server_args.skip_tokenizer_init:
self.tokenizer = None
else:
......@@ -95,208 +94,21 @@ class DetokenizerManager:
self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
self.is_dummy = server_args.load_format == "dummy"
self.tokenizer_worker_num = server_args.tokenizer_worker_num
self._request_dispatcher = TypeBasedDispatcher(
[
(BatchEmbeddingOut, self.handle_batch_embedding_out),
(BatchTokenIDOut, self.handle_batch_token_id_out),
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
(MultiTokenizerRegisterReq, lambda x: None),
]
)
def event_loop(self):
"""The event loop that handles requests"""
while True:
try:
recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj)
if self.tokenizer_worker_num <= 1:
self.send_to_tokenizer.send_pyobj(output)
else:
# Extract worker_id from rid
if isinstance(recv_obj.rids, list):
worker_ids = get_workerids_from_rids(recv_obj.rids)
else:
raise RuntimeError(
f"tokenizer_worker_num > 1, recv_obj.rids must be list"
)
if not hasattr(self, "tokenizer_mapping"):
self.tokenizer_mapping = {}
# Create ZMQ context if needed
if not hasattr(self, "_zmq_context"):
self._zmq_context = zmq.Context()
# Send data using the corresponding socket
for i, worker_id in enumerate(worker_ids):
if worker_id not in self.tokenizer_mapping:
# register the worker if not already done
if isinstance(recv_obj, MultiTokenizerRegisterReq):
self.init_tokenizer_mapping(recv_obj, worker_id)
else:
logger.error(
f"Worker {worker_id} not registered and not found in tokenizer mapping . "
"Please ensure the worker is registered correctly."
)
continue
else:
if isinstance(recv_obj, MultiTokenizerRegisterReq):
continue
# Create a new output object based on the type
if isinstance(output, BatchEmbeddingOut):
new_output = BatchEmbeddingOut(
rids=[output.rids[i]],
finished_reasons=[output.finished_reasons[i]],
embeddings=[output.embeddings[i]],
prompt_tokens=[output.prompt_tokens[i]],
cached_tokens=[output.cached_tokens[i]],
)
elif isinstance(output, BatchStrOut):
new_output = BatchStrOut(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
else None
),
output_strs=(
[output.output_strs[i]]
if len(output.output_strs) > i
else None
),
output_ids=(
[output.output_ids[i]]
if output.output_ids and len(output.output_ids) > i
else None
),
prompt_tokens=(
[output.prompt_tokens[i]]
if len(output.prompt_tokens) > i
else None
),
completion_tokens=(
[output.completion_tokens[i]]
if len(output.completion_tokens) > i
else None
),
cached_tokens=(
[output.cached_tokens[i]]
if len(output.cached_tokens) > i
else None
),
spec_verify_ct=(
[output.spec_verify_ct[i]]
if len(output.spec_verify_ct) > i
else None
),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
else None
),
input_token_logprobs_idx=(
[output.input_token_logprobs_idx[i]]
if output.input_token_logprobs_idx
else None
),
output_token_logprobs_val=(
[output.output_token_logprobs_val[i]]
if output.output_token_logprobs_val
else None
),
output_token_logprobs_idx=(
[output.output_token_logprobs_idx[i]]
if output.output_token_logprobs_idx
else None
),
input_top_logprobs_val=(
[output.input_top_logprobs_val[i]]
if output.input_top_logprobs_val
else None
),
input_top_logprobs_idx=(
[output.input_top_logprobs_idx[i]]
if output.input_top_logprobs_idx
else None
),
output_top_logprobs_val=(
[output.output_top_logprobs_val[i]]
if output.output_top_logprobs_val
else None
),
output_top_logprobs_idx=(
[output.output_top_logprobs_idx[i]]
if output.output_top_logprobs_idx
else None
),
input_token_ids_logprobs_val=(
[output.input_token_ids_logprobs_val[i]]
if output.input_token_ids_logprobs_val
else None
),
input_token_ids_logprobs_idx=(
[output.input_token_ids_logprobs_idx[i]]
if output.input_token_ids_logprobs_idx
else None
),
output_token_ids_logprobs_val=(
[output.output_token_ids_logprobs_val[i]]
if output.output_token_ids_logprobs_val
else None
),
output_token_ids_logprobs_idx=(
[output.output_token_ids_logprobs_idx[i]]
if output.output_token_ids_logprobs_idx
else None
),
output_hidden_states=(
[output.output_hidden_states[i]]
if output.output_hidden_states
else None
),
)
elif isinstance(output, BatchMultimodalOut):
new_output = BatchMultimodalOut(
rids=[output.rids[i]],
finished_reasons=[output.finished_reasons[i]],
prompt_tokens=[output.prompt_tokens[i]],
completion_tokens=[output.completion_tokens[i]],
cached_tokens=[output.cached_tokens[i]],
)
else:
new_output = output
try:
self.tokenizer_mapping[worker_id].send_pyobj(new_output)
except zmq.error.ZMQError as e:
logger.info(
f"ZMQ error when sending to worker {worker_id}: {e}"
)
except Exception as e:
logger.error(f"Error in detokenizer event loop: {e}")
raise e
def init_tokenizer_mapping(
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
):
"""init tokenizer mapping from register request"""
ipc_name = recv_obj.ipc_name
worker_id_int = int(worker_id)
if worker_id_int not in self.tokenizer_mapping:
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
self.tokenizer_mapping[worker_id_int] = socket
logger.info(
f"Detokenizer Manager Created ZMQ socket for worker {worker_id} with ipc_name {ipc_name}"
)
else:
logger.info(
f"ZMQ socket for worker {worker_id} already exists, skipping creation"
)
recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj)
self.send_to_tokenizer.send_pyobj(output)
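# Standalone sketch (not part of DetokenizerManager, and not sglang's actual
# TypeBasedDispatcher implementation): a minimal dispatch-by-type helper matching
# how _request_dispatcher is built from (type, handler) pairs and called above.
class _TypeDispatcherSketch:
    def __init__(self, mapping):
        self._mapping = list(mapping)

    def __call__(self, obj):
        for typ, handler in self._mapping:
            if isinstance(obj, typ):
                return handler(obj)
        raise ValueError(f"no handler registered for {type(obj)}")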
def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
......
......@@ -782,13 +782,12 @@ class BatchEmbeddingOut:
@dataclass
class FlushCacheReqInput:
rids: Optional[Union[List[str], str]] = None
pass
@dataclass
class FlushCacheReqOutput:
success: bool
rids: Optional[Union[List[str], str]] = None
@dataclass
......@@ -799,7 +798,6 @@ class UpdateWeightFromDiskReqInput:
load_format: Optional[str] = None
# Whether to abort all requests before updating weights
abort_all_requests: bool = False
rids: Optional[Union[List[str], str]] = None
@dataclass
......@@ -808,7 +806,6 @@ class UpdateWeightFromDiskReqOutput:
message: str
# Number of paused requests during weight sync.
num_paused_requests: Optional[int] = 0
rids: Optional[Union[List[str], str]] = None
@dataclass
......@@ -822,14 +819,12 @@ class UpdateWeightsFromDistributedReqInput:
flush_cache: bool = True
# Whether to abort all requests before updating weights
abort_all_requests: bool = False
rids: Optional[Union[List[str], str]] = None
@dataclass
class UpdateWeightsFromDistributedReqOutput:
success: bool
message: str
rids: Optional[Union[List[str], str]] = None
@dataclass
......@@ -847,14 +842,12 @@ class UpdateWeightsFromTensorReqInput:
flush_cache: bool = True
# Whether to abort all requests before updating weights
abort_all_requests: bool = False
rids: Optional[Union[List[str], str]] = None
@dataclass
class UpdateWeightsFromTensorReqOutput:
success: bool
message: str
rids: Optional[Union[List[str], str]] = None
@dataclass
......@@ -871,27 +864,23 @@ class InitWeightsUpdateGroupReqInput:
group_name: str = "weight_update_group"
# The backend
backend: str = "nccl"
rids: Optional[Union[List[str], str]] = None
@dataclass
class InitWeightsUpdateGroupReqOutput:
success: bool
message: str
rids: Optional[Union[List[str], str]] = None
@dataclass
class GetWeightsByNameReqInput:
name: str
truncate_size: int = 100
rids: Optional[Union[List[str], str]] = None
@dataclass
class GetWeightsByNameReqOutput:
parameter: list
rids: Optional[Union[List[str], str]] = None
@dataclass
......@@ -899,12 +888,11 @@ class ReleaseMemoryOccupationReqInput:
# Optional tags to identify the memory region, which is primarily used for RL
# Currently we only support `weights` and `kv_cache`
tags: Optional[List[str]] = None
rids: Optional[Union[List[str], str]] = None
@dataclass
class ReleaseMemoryOccupationReqOutput:
rids: Optional[Union[List[str], str]] = None
pass
@dataclass
......@@ -912,23 +900,21 @@ class ResumeMemoryOccupationReqInput:
# Optional tags to identify the memory region, which is primarily used for RL
# Currently we only support `weights` and `kv_cache`
tags: Optional[List[str]] = None
rids: Optional[Union[List[str], str]] = None
@dataclass
class ResumeMemoryOccupationReqOutput:
rids: Optional[Union[List[str], str]] = None
pass
@dataclass
class SlowDownReqInput:
forward_sleep_time: Optional[float]
rids: Optional[Union[List[str], str]] = None
@dataclass
class SlowDownReqOutput:
rids: Optional[Union[List[str], str]] = None
pass
@dataclass
......@@ -937,37 +923,29 @@ class AbortReq:
rid: str = ""
# Whether to abort all requests
abort_all: bool = False
rids: Optional[Union[List[str], str]] = None
# The finished reason data
finished_reason: Optional[Dict[str, Any]] = None
def __post_init__(self):
self.rids = self.rid
@dataclass
class GetInternalStateReq:
rids: Optional[Union[List[str], str]] = None
pass
@dataclass
class GetInternalStateReqOutput:
internal_state: Dict[Any, Any]
rids: Optional[Union[List[str], str]] = None
@dataclass
class SetInternalStateReq:
server_args: Dict[str, Any]
rids: Optional[Union[List[str], str]] = None
@dataclass
class SetInternalStateReqOutput:
updated: bool
server_args: Dict[str, Any]
rids: Optional[Union[List[str], str]] = None
@dataclass
......@@ -983,7 +961,6 @@ class ProfileReqInput:
profile_by_stage: bool = False
with_stack: Optional[bool] = None
record_shapes: Optional[bool] = None
rids: Optional[Union[List[str], str]] = None
class ProfileReqType(Enum):
......@@ -1002,14 +979,12 @@ class ProfileReq:
with_stack: Optional[bool] = None
record_shapes: Optional[bool] = None
profile_id: Optional[str] = None
rids: Optional[Union[List[str], str]] = None
@dataclass
class ProfileReqOutput:
success: bool
message: str
rids: Optional[Union[List[str], str]] = None
@dataclass
......@@ -1018,32 +993,27 @@ class ConfigureLoggingReq:
log_requests_level: Optional[int] = None
dump_requests_folder: Optional[str] = None
dump_requests_threshold: Optional[int] = None
rids: Optional[Union[List[str], str]] = None
@dataclass
class OpenSessionReqInput:
capacity_of_str_len: int
session_id: Optional[str] = None
rids: Optional[Union[List[str], str]] = None
@dataclass
class CloseSessionReqInput:
session_id: str
rids: Optional[Union[List[str], str]] = None
@dataclass
class OpenSessionReqOutput:
session_id: Optional[str]
success: bool
rids: Optional[Union[List[str], str]] = None
@dataclass
class HealthCheckOutput:
rids: Optional[Union[List[str], str]] = None
pass
......@@ -1055,7 +1025,7 @@ class ExpertDistributionReq(Enum):
@dataclass
class ExpertDistributionReqOutput:
rids: Optional[Union[List[str], str]] = None
pass
@dataclass
......@@ -1080,21 +1050,18 @@ class ParseFunctionCallReq:
tool_call_parser: Optional[str] = (
None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
)
rids: Optional[Union[List[str], str]] = None
@dataclass
class SeparateReasoningReqInput:
text: str # The text to parse.
reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1".
rids: Optional[Union[List[str], str]] = None
@dataclass
class VertexGenerateReqInput:
instances: List[dict]
parameters: Optional[dict] = None
rids: Optional[Union[List[str], str]] = None
@dataclass
......@@ -1119,7 +1086,6 @@ class LoadLoRAAdapterReqInput:
pinned: bool = False
# The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
rids: Optional[Union[List[str], str]] = None
def to_ref(self) -> LoRARef:
return LoRARef(
......@@ -1136,7 +1102,6 @@ class UnloadLoRAAdapterReqInput:
lora_name: str
# The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
rids: Optional[Union[List[str], str]] = None
def to_ref(self) -> LoRARef:
return LoRARef(
......@@ -1150,18 +1115,11 @@ class LoRAUpdateResult:
success: bool
error_message: Optional[str] = None
loaded_adapters: Optional[Dict[str, LoRARef]] = None
rids: Optional[Union[List[str], str]] = None
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
@dataclass
class MultiTokenizerRegisterReq:
rids: Optional[Union[List[str], str]] = None
ipc_name: Optional[str] = None
class BlockReqType(Enum):
BLOCK = 1
UNBLOCK = 2
......
......@@ -79,7 +79,6 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
MultiTokenizerRegisterReq,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -248,6 +247,7 @@ class Scheduler(
# Init inter-process communication
context = zmq.Context(2)
self.idle_sleeper = None
if self.pp_rank == 0 and self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
......@@ -522,7 +522,6 @@ class Scheduler(
(ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
]
)
......@@ -1065,8 +1064,6 @@ class Scheduler(
if self.recv_from_rpc is not None:
self.recv_from_rpc.send_pyobj(output)
else:
if recv_req.rids is not None:
output.rids = recv_req.rids
self.send_to_tokenizer.send_pyobj(output)
def handle_generate_request(
......@@ -2407,10 +2404,6 @@ class Scheduler(
result = self.tp_worker.unload_lora_adapter(recv_req)
return result
def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
self.send_to_detokenizer.send_pyobj(recv_req)
return recv_req
def slow_down(self, recv_req: SlowDownReqInput):
t = recv_req.forward_sleep_time
if t is not None and t <= 0:
......
......@@ -51,7 +51,6 @@ class ServerArgs:
model_path: str
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
tokenizer_worker_num: int = 1
skip_tokenizer_init: bool = False
load_format: str = "auto"
model_loader_extra_config: str = "{}"
......@@ -732,12 +731,6 @@ class ServerArgs:
default=ServerArgs.tokenizer_path,
help="The path of the tokenizer.",
)
parser.add_argument(
"--tokenizer-worker-num",
type=int,
default=ServerArgs.tokenizer_worker_num,
help="The worker num of the tokenizer manager.",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
......@@ -2096,9 +2089,6 @@ class ServerArgs:
self.chunked_prefill_size % self.page_size == 0
), "chunked_prefill_size must be divisible by page_size"
# Check multi-tokenizer settings
assert self.tokenizer_worker_num > 0, "tokenizer_worker_num must be >= 1"
def check_lora_server_args(self):
assert (
self.max_loras_per_batch > 0
......@@ -2264,9 +2254,6 @@ class PortArgs:
# The ipc filename for Scheduler to send metrics
metrics_ipc_name: str
# The ipc filename for the main tokenizer and worker tokenizers
tokenizer_worker_ipc_name: Optional[str]
@staticmethod
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
if server_args.nccl_port is None:
......@@ -2290,7 +2277,6 @@ class PortArgs:
nccl_port=nccl_port,
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
tokenizer_worker_ipc_name=None,
)
else:
# DP attention. Use TCP + port to handle both single-node and multi-node.
......@@ -2324,7 +2310,6 @@ class PortArgs:
nccl_port=nccl_port,
rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
tokenizer_worker_ipc_name=None,
)
......
......@@ -2754,20 +2754,6 @@ def lru_cache_frozenset(maxsize=128):
return decorator
def get_workerids_from_rids(rids):
if isinstance(rids, list):
worker_ids = [int(rid.split("_")[0]) for rid in rids]
elif isinstance(rids, str):
worker_ids = [int(rids.split("_")[0])]
else:
worker_ids = []
return worker_ids
def get_origin_rid(rid):
return rid.split("_", 1)[1] if "_" in rid else rid
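# Illustration (a sketch only) of the rid convention the two helpers above assume:
# a tokenizer worker id prefix joined to the original request id with "_".
def _rid_convention_example():
    assert get_workerids_from_rids(["3_abc-123"]) == [3]  # made-up rid
    assert get_origin_rid("3_abc-123") == "abc-123"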
def apply_module_patch(target_module, target_function, wrappers):
original_module, original_function = parse_module_path(
target_module, target_function, False
......
......@@ -78,7 +78,6 @@ suites = {
TestFile("test_mla_int8_deepseek_v3.py", 429),
TestFile("test_mla_flashinfer.py", 302),
TestFile("test_mla_fp8.py", 93),
TestFile("test_multi_tokenizer.py", 200),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 234),
TestFile("test_penalty.py", 41),
......
import inspect
import unittest
from dataclasses import fields, is_dataclass
from types import SimpleNamespace
import sglang.srt.managers.io_struct as io_struct
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
auto_config_device,
get_benchmark_args,
is_in_ci,
popen_launch_server,
run_benchmark,
write_github_step_summary,
)
class TestMultiTokenizer(CustomTestCase):
# from test_hicache.py
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--tokenizer-worker-num",
8,
"--mem-fraction-static",
0.7,
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], 0.65)
def test_all_io_struct(self):
print("check all req types in io_struct.py")
result = []
for name, obj in inspect.getmembers(io_struct):
if inspect.isclass(obj) and is_dataclass(obj):
field_names = [f.name for f in fields(obj)]
if "rids" in field_names or "rid" in field_names:
continue
result.append(name)
print(f"WARNING:Some Request types in io_struct.py have no rids: {result}")
print(
"If a special request type can't work, check the rids field which is needed for multi-tokenizer."
)
def test_multi_tokenizer_ttft(self):
# from test_bench_serving.py run_bench_serving
args = get_benchmark_args(
base_url=self.base_url,
dataset_name="random",
dataset_path="",
tokenizer=None,
num_prompts=100,
random_input_len=4096,
random_output_len=2048,
sharegpt_context_len=None,
request_rate=1,
disable_stream=False,
disable_ignore_eos=False,
seed=0,
device=auto_config_device(),
lora_name=None,
)
res = run_benchmark(args)
if is_in_ci():
write_github_step_summary(
f"### test_multi_tokenizer_ttft\n"
f"median_e2e_latency_ms: {res['median_e2e_latency_ms']:.2f} ms\n"
)
self.assertLess(res["median_e2e_latency_ms"], 11000)
self.assertLess(res["median_ttft_ms"], 86)
self.assertLess(res["median_itl_ms"], 10)
if __name__ == "__main__":
unittest.main()