"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "fb115c2295b6bc2f2ed7689e1ace00f25d882777"
Unverified commit 5f77e129, authored by ybyang, committed by GitHub

Support Multi Process Tokenizer Manager (#6555) (#8964)


Signed-off-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
Co-authored-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
parent 4750cddf
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -814,18 +815,24 @@ def _launch_subprocesses(
             ),
         )
         detoken_proc.start()

-    # Launch tokenizer process
-    tokenizer_manager = TokenizerManager(server_args, port_args)
-
-    # Initialize templates
-    template_manager = TemplateManager()
-    template_manager.initialize_templates(
-        tokenizer_manager=tokenizer_manager,
-        model_path=server_args.model_path,
-        chat_template=server_args.chat_template,
-        completion_template=server_args.completion_template,
-    )
+    if server_args.tokenizer_worker_num > 1:
+        # Launch multi-tokenizer router
+        tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
+
+        # Initialize templates
+        template_manager = None
+    else:
+        # Launch tokenizer process
+        tokenizer_manager = TokenizerManager(server_args, port_args)
+
+        # Initialize templates
+        template_manager = TemplateManager()
+        template_manager.initialize_templates(
+            tokenizer_manager=tokenizer_manager,
+            model_path=server_args.model_path,
+            chat_template=server_args.chat_template,
+            completion_template=server_args.completion_template,
+        )

     # Wait for the model to finish loading
     scheduler_infos = []
@@ -23,6 +23,7 @@ import json
 import logging
 import multiprocessing as multiprocessing
 import os
+import tempfile
 import threading
 import time
 from http import HTTPStatus
@@ -91,11 +92,18 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightVersionReqInput,
     VertexGenerateReqInput,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import (
+    MultiTokenizerManager,
+    deserialize_data,
+    get_main_process_id,
+    read_from_shared_memory,
+    write_data_for_multi_tokenizer,
+)
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
 from sglang.srt.reasoning_parser import ReasoningParser
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
     add_prometheus_middleware,
@@ -130,8 +138,79 @@ def set_global_state(global_state: _GlobalState):
     _global_state = global_state
+# Function to set up all middlewares for multi-tokenizer compatibility
+def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
+    """Setup all middlewares for both single and multi-process modes"""
+    worker_pid = os.getpid()
+
+    if api_key:
+        add_api_key_middleware(app, api_key)
+        logger.info(f"Worker {worker_pid} added API key middleware")
+
+    if enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+        logger.info(f"Worker {worker_pid} added prometheus middleware")
+
+
+async def init_multi_tokenizer() -> ServerArgs:
+    """Read args information from shm and init tokenizer manager for current process"""
+    pid = os.getpid()
+    main_pid = get_main_process_id()
+    logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
+
+    # Read configuration from shared memory
+    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
+    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
+    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
+    port_args, server_args = deserialize_data(port_args_data, server_args_data)
+    scheduler_info = scheduler_info_data
+
+    port_args.tokenizer_ipc_name = (
+        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
+    )
+
+    # Launch multi-tokenizer manager process
+    tokenizer_manager = MultiTokenizerManager(server_args, port_args)
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
+
+    # Register this tokenizer with the main tokenizer manager
+    await tokenizer_manager.register_to_main_tokenizer_manager()
+
+    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
+    set_global_state(
+        _GlobalState(
+            tokenizer_manager=tokenizer_manager,
+            template_manager=template_manager,
+            scheduler_info=scheduler_info,
+        )
+    )
+    return server_args
+
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
+    server_args = getattr(fast_api_app, "server_args", None)
+    if server_args is None:
+        # Initialize multi-tokenizer support for worker processes
+        fast_api_app.server_args = await init_multi_tokenizer()
+        setup_middlewares(
+            fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
+        )
+        fast_api_app.warmup_thread = threading.Thread(
+            target=_wait_and_warmup,
+            args=(
+                fast_api_app.server_args,
+                None,  # pipe_finish_writer not needed in worker
+                None,  # launch_callback not needed in worker
+            ),
+        )
+
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
         _global_state.tokenizer_manager, _global_state.template_manager
@@ -191,7 +270,15 @@ async def lifespan(fast_api_app: FastAPI):
     warmup_thread = getattr(fast_api_app, "warmup_thread", None)
     if warmup_thread is not None:
         warmup_thread.start()
-    yield
+
+    try:
+        yield
+    finally:
+        if server_args.tokenizer_worker_num > 1:
+            pid = os.getpid()
+            logger.info(f"uvicorn worker {pid} ending...")
+            warmup_thread.join()
+            logger.info(f"uvicorn worker {pid} ended.")


 # Fast API
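Note: the shared-memory helpers used above (write_data_for_multi_tokenizer, read_from_shared_memory, deserialize_data) live in the collapsed multi_tokenizer_mixin diff, so their bodies are not visible here. The following is a minimal sketch of how such helpers could work, assuming the args are pickled into named shared-memory segments keyed by the main process PID; the names and signatures are illustrative, not the PR's actual API:

# Illustrative sketch only; the real helpers are in
# sglang.srt.managers.multi_tokenizer_mixin (collapsed in this diff).
import pickle
from multiprocessing import shared_memory


def write_to_shared_memory_sketch(name: str, obj) -> shared_memory.SharedMemory:
    data = pickle.dumps(obj)
    shm = shared_memory.SharedMemory(name=name, create=True, size=len(data))
    shm.buf[: len(data)] = data  # copy the pickled bytes into the segment
    return shm  # caller keeps the handle so it can unlink() at shutdown


def read_from_shared_memory_sketch(name: str):
    shm = shared_memory.SharedMemory(name=name)  # attach to an existing segment
    obj = pickle.loads(bytes(shm.buf))  # pickle ignores padding after the STOP opcode
    shm.close()
    return obj

This also explains the finally block in launch_server below: whichever process created the segments must unlink() them, otherwise the /dev/shm entries leak across restarts.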
@@ -1078,9 +1165,19 @@ def launch_server(
     1. The HTTP server, Engine, and TokenizerManager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
-        server_args=server_args
-    )
+    if server_args.tokenizer_worker_num > 1:
+        port_args = PortArgs.init_new(server_args)
+        port_args.tokenizer_worker_ipc_name = (
+            f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
+        )
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+            server_args=server_args, port_args=port_args
+        )
+    else:
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+            server_args=server_args,
+        )
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
@@ -1089,42 +1186,75 @@ def launch_server(
         )
     )

-    # Add api key authorization
-    if server_args.api_key:
-        add_api_key_middleware(app, server_args.api_key)
-
-    # Add prometheus middleware
-    if server_args.enable_metrics:
-        add_prometheus_middleware(app)
-        enable_func_timer()
-
-    # Send a warmup request - we will create the thread launch it
-    # in the lifespan after all other warmups have fired.
-    warmup_thread = threading.Thread(
-        target=_wait_and_warmup,
-        args=(
-            server_args,
-            pipe_finish_writer,
-            launch_callback,
-        ),
-    )
-    app.warmup_thread = warmup_thread
+    if server_args.tokenizer_worker_num > 1:
+        port_args_shm, server_args_shm, scheduler_info_shm = (
+            write_data_for_multi_tokenizer(
+                port_args,
+                server_args,
+                scheduler_info,
+            )
+        )
+    else:
+        # Add api key authorization
+        if server_args.api_key:
+            add_api_key_middleware(app, server_args.api_key)
+
+        # Add prometheus middleware
+        if server_args.enable_metrics:
+            add_prometheus_middleware(app)
+            enable_func_timer()
+
+        # Send a warmup request - we will create the thread launch it
+        # in the lifespan after all other warmups have fired.
+        warmup_thread = threading.Thread(
+            target=_wait_and_warmup,
+            args=(
+                server_args,
+                pipe_finish_writer,
+                launch_callback,
+            ),
+        )
+        app.warmup_thread = warmup_thread

     try:
         # Update logging configs
         set_uvicorn_logging_configs()
         app.server_args = server_args

         # Listen for HTTP requests
-        uvicorn.run(
-            app,
-            host=server_args.host,
-            port=server_args.port,
-            log_level=server_args.log_level_http or server_args.log_level,
-            timeout_keep_alive=5,
-            loop="uvloop",
-        )
+        if server_args.tokenizer_worker_num > 1:
+            from uvicorn.config import LOGGING_CONFIG
+
+            LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
+                "handlers": ["default"],
+                "level": "INFO",
+                "propagate": False,
+            }
+
+            uvicorn.run(
+                "sglang.srt.entrypoints.http_server:app",
+                host=server_args.host,
+                port=server_args.port,
+                log_level=server_args.log_level_http or server_args.log_level,
+                timeout_keep_alive=5,
+                loop="uvloop",
+                workers=server_args.tokenizer_worker_num,
+            )
+        else:
+            uvicorn.run(
+                app,
+                host=server_args.host,
+                port=server_args.port,
+                log_level=server_args.log_level_http or server_args.log_level,
+                timeout_keep_alive=5,
+                loop="uvloop",
+            )
    finally:
-        warmup_thread.join()
+        if server_args.tokenizer_worker_num > 1:
+            port_args_shm.unlink()
+            server_args_shm.unlink()
+            scheduler_info_shm.unlink()
+            _global_state.tokenizer_manager.clear_tokenizer_mapping()
+        else:
+            warmup_thread.join()


 def _execute_server_warmup(
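Note on the multi-worker branch passing an import string: uvicorn forks workers by re-importing the application in each child process, so uvicorn.run() rejects a bare app object when workers > 1 ("You must pass the application as an import string to enable ... 'workers'."). A standalone sketch of the pattern; "myproject.server:app" is an illustrative module path, not from this PR:

import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "myproject.server:app",  # import string, resolvable in every worker
        host="0.0.0.0",
        port=8000,
        workers=4,  # each worker re-runs the FastAPI lifespan shown above
    )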
@@ -32,11 +32,14 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
     FreezeGCReq,
+    MultiTokenizerRegisterReq,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
     freeze_gc,
+    get_worker_ids_from_req_rids,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
@@ -67,7 +70,7 @@ class DecodeStatus:
     sent_offset: int = 0


-class DetokenizerManager:
+class DetokenizerManager(MultiTokenizerMixin):
     """DetokenizerManager is a process that detokenizes the token ids."""

     def __init__(
@@ -102,6 +105,7 @@ class DetokenizerManager:
                 (BatchEmbeddingOut, self.handle_batch_embedding_out),
                 (BatchTokenIDOut, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
+                (MultiTokenizerRegisterReq, lambda x: x),
                 (FreezeGCReq, self.handle_freeze_gc_req),
             ]
         )
@@ -116,6 +120,39 @@ class DetokenizerManager:
             if output is not None:
                 self.send_to_tokenizer.send_pyobj(output)

+    def multi_tokenizer_manager_event_loop(self):
+        """The event loop that handles requests, for multi tokenizer manager mode only"""
+        self.create_sockets_mapping()
+        while True:
+            recv_obj = self.recv_from_scheduler.recv_pyobj()
+            output = self._request_dispatcher(recv_obj)
+            if output is None:
+                continue
+            # Extract worker_id from rid
+            if isinstance(recv_obj.rids, list):
+                worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
+            else:
+                raise RuntimeError(
+                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
+                )
+
+            # Send data using the corresponding socket
+            for i, worker_id in enumerate(worker_ids):
+                if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                    if self.register_tokenizer_ipc(recv_obj, worker_id):
+                        logger.info(
+                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
+                        )
+                    continue
+                else:
+                    if worker_id not in self.tokenizer_mapping:
+                        logger.error(
+                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
+                        )
+                        continue
+                    new_output = self._handle_output_by_index(output, i)
+                    self.tokenizer_mapping[worker_id].send_pyobj(new_output)
+
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
     ):
@@ -285,8 +322,12 @@ def run_detokenizer_process(
     try:
         manager = DetokenizerManager(server_args, port_args)
-        manager.event_loop()
+        if server_args.tokenizer_worker_num > 1:
+            manager.multi_tokenizer_manager_event_loop()
+        else:
+            manager.event_loop()
     except Exception:
+        manager.clear_tokenizer_mapping()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
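A self-contained toy version of the fan-out above may help: a batched output can mix requests from several HTTP workers, so each entry is routed by the worker-id prefix that the TokenizerManager stamps onto rids. Dict-backed outboxes stand in for the per-worker ZMQ sockets, and list indexing stands in for _handle_output_by_index from the collapsed mixin:

from collections import defaultdict

# rids as they arrive at the detokenizer: "<worker_id>_<original rid>"
batch_rids = ["1234_req-a", "5678_req-b", "1234_req-c"]
batch_texts = ["hello", "world", "again"]

outboxes = defaultdict(list)  # worker_id -> messages (stand-in for ZMQ sockets)
worker_ids = [int(rid.split("_")[0]) for rid in batch_rids]

for i, worker_id in enumerate(worker_ids):
    # route only entry i of the batch to the worker that owns that rid
    outboxes[worker_id].append((batch_rids[i], batch_texts[i]))

assert outboxes[1234] == [("1234_req-a", "hello"), ("1234_req-c", "again")]
assert outboxes[5678] == [("5678_req-b", "world")]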
@@ -983,6 +983,11 @@ class AbortReq:
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
+    # used in MultiTokenizerManager mode
+    rids: Optional[Union[List[str], str]] = None
+
+    def __post_init__(self):
+        self.rids = self.rid


 @dataclass
@@ -1183,6 +1188,18 @@ class LoRAUpdateResult:
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


+@dataclass
+class MultiTokenizerRegisterReq:
+    rids: Optional[Union[List[str], str]] = None
+    ipc_name: Optional[str] = None
+
+
+@dataclass
+class MultiTokenizerWarpper:
+    worker_id: int
+    obj: Optional[Any] = None
+
+
 class BlockReqType(Enum):
     BLOCK = 1
     UNBLOCK = 2
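The __post_init__ above mirrors rid into rids so that abort messages can be routed with the same worker-id-prefix logic as batched outputs. A stripped-down stand-in (the real AbortReq has more fields):

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class AbortReqDemo:  # stand-in, not the real AbortReq
    rid: str = ""
    abort_all: bool = False
    rids: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        self.rids = self.rid  # mirror, so routing code treats all messages alike


req = AbortReqDemo(rid="1234_req-a")
assert req.rids == "1234_req-a"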
This diff is collapsed. (It is presumably the new multi_tokenizer_mixin.py, which the imports above reference for MultiTokenizerRouter, MultiTokenizerManager, and MultiTokenizerMixin.)
@@ -84,6 +84,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -257,7 +259,6 @@ class Scheduler(
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
-
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -540,6 +541,7 @@ class Scheduler(
                 (ExpertDistributionReq, self.expert_distribution_handle),
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
             ]
         )
@@ -1101,6 +1103,17 @@ class Scheduler(
                     )
                     self.send_to_tokenizer.send_pyobj(abort_req)
                 continue
+
+            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWarpper):
+                worker_id = recv_req.worker_id
+                recv_req = recv_req.obj
+                output = self._request_dispatcher(recv_req)
+                if output is not None:
+                    output = MultiTokenizerWarpper(worker_id, output)
+                    self.send_to_tokenizer.send_pyobj(output)
+                continue
+
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
@@ -2474,6 +2487,10 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

+    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return recv_req
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
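A self-contained sketch of the wrap/unwrap round trip: the TokenizerManager wraps control requests in MultiTokenizerWarpper with its worker_id, and the Scheduler unwraps, dispatches, and re-wraps the reply so it can be routed back to the originating HTTP worker. Dict and lambda stand-ins replace ZMQ and the real dispatcher:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MultiTokenizerWarpper:  # local stand-in; name as in io_struct.py
    worker_id: int
    obj: Optional[Any] = None


def scheduler_step(recv_req, dispatch):
    """Mimic the unwrap/dispatch/re-wrap logic in the scheduler event loop."""
    if isinstance(recv_req, MultiTokenizerWarpper):
        worker_id = recv_req.worker_id
        output = dispatch(recv_req.obj)
        return MultiTokenizerWarpper(worker_id, output) if output is not None else None
    return dispatch(recv_req)


reply = scheduler_step(
    MultiTokenizerWarpper(worker_id=4321, obj="flush_cache"),
    dispatch=lambda req: f"done:{req}",
)
assert reply.worker_id == 4321 and reply.obj == "done:flush_cache"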
@@ -94,6 +94,7 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -131,6 +132,7 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
+    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -266,9 +268,15 @@ class TokenizerManager:
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
+        if self.server_args.tokenizer_worker_num > 1:
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+            )
+        else:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )

         # Request states
         self.no_create_loop = False
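The boolean flags above deserve a note: get_zmq_socket's final argument appears to be a bind flag, so in multi-tokenizer mode many HTTP workers PUSH-connect (bind=False) to one well-known endpoint that the router side binds, instead of each side binding its own socket. A self-contained pyzmq demo of that many-connect/one-bind fan-in (the endpoint path is hypothetical, and ipc:// endpoints are POSIX-only):

import zmq

ctx = zmq.Context.instance()
endpoint = "ipc:///tmp/demo_tokenizer_worker"  # hypothetical IPC path

pull = ctx.socket(zmq.PULL)
pull.bind(endpoint)  # single binder, like the scheduler-facing side

pushers = []
for i in range(3):
    push = ctx.socket(zmq.PUSH)
    push.connect(endpoint)  # many connectors, like the HTTP worker processes
    pushers.append(push)

for i, push in enumerate(pushers):
    push.send_pyobj({"worker": i})

messages = [pull.recv_pyobj() for _ in range(3)]
assert len(messages) == 3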
@@ -312,35 +320,7 @@ class TokenizerManager:
         self.lora_update_lock = asyncio.Lock()

         # For PD disaggregation
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv bootstrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_disaggregation()

         # For load balancing
         self.current_load = 0
@@ -488,6 +468,37 @@ class TokenizerManager:
             ]
         )

+    def init_disaggregation(self):
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv bootstrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = f"Failed create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    raise error_message
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -497,6 +508,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

+        if self.server_args.tokenizer_worker_num > 1:
+            # Modify rid, add worker_id
+            if isinstance(obj.rid, list):
+                # If it's an array, add worker_id prefix to each element
+                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
+            else:
+                # If it's a single value, add worker_id prefix
+                obj.rid = f"{self.worker_id}_{obj.rid}"
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
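A toy version of the stamping above, assuming worker_id is the HTTP worker's PID (the _Communicator change later in this file wraps with os.getpid() as well):

worker_id = 4321  # stand-in for the worker's PID


def stamp(rid):
    """Prefix rid(s) with the owning worker's id, as generate_request does."""
    if isinstance(rid, list):
        return [f"{worker_id}_{r}" for r in rid]
    return f"{worker_id}_{rid}"


assert stamp("req-a") == "4321_req-a"
assert stamp(["req-a", "req-b"]) == ["4321_req-a", "4321_req-b"]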
@@ -1096,6 +1116,8 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1315,6 +1337,8 @@ class TokenizerManager:
         elif obj.session_id in self.session_futures:
             return None

+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1590,7 +1614,6 @@ class TokenizerManager:
     async def handle_loop(self):
         """The event loop that handles requests"""
         while True:
-
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
@@ -1610,9 +1633,12 @@ class TokenizerManager:
                 )
                 continue

+            origin_rid = rid
+            if self.server_args.tokenizer_worker_num > 1:
+                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": rid,
+                "id": origin_rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1918,6 +1944,9 @@ class TokenizerManager:
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
+        origin_rid = recv_obj.rid
+        if self.server_args.tokenizer_worker_num > 1:
+            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1930,7 +1959,7 @@ class TokenizerManager:
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": recv_obj.rid,
+                    "id": origin_rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2116,6 +2145,8 @@ T = TypeVar("T")
 class _Communicator(Generic[T]):
     """Note: The communicator now only run up to 1 in-flight request at any time."""

+    enable_multi_tokenizer = False
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
@@ -2132,6 +2163,8 @@ class _Communicator(Generic[T]):
         assert self._result_values is None

         if obj:
+            if _Communicator.enable_multi_tokenizer:
+                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
             self._sender.send_pyobj(obj)
         self._result_event = asyncio.Event()
@@ -128,6 +128,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    tokenizer_worker_num: int = 1
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     model_loader_extra_config: str = "{}"
@@ -827,6 +828,12 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
+        parser.add_argument(
+            "--tokenizer-worker-num",
+            type=int,
+            default=ServerArgs.tokenizer_worker_num,
+            help="The number of tokenizer manager workers.",
+        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -2176,6 +2183,9 @@ class ServerArgs:
             self.chunked_prefill_size % self.page_size == 0
         ), "chunked_prefill_size must be divisible by page_size"

+        # Check multi tokenizer
+        assert self.tokenizer_worker_num > 0, "tokenizer_worker_num must be >= 1"
+
     def check_lora_server_args(self):
         assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
@@ -2419,6 +2429,9 @@ class PortArgs:
     # The ipc filename for Scheduler to send metrics
     metrics_ipc_name: str

+    # The ipc filename for the main tokenizer and the worker tokenizers
+    tokenizer_worker_ipc_name: Optional[str]
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         if server_args.nccl_port is None:
@@ -2442,6 +2455,7 @@ class PortArgs:
             nccl_port=nccl_port,
             rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
             metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+            tokenizer_worker_ipc_name=None,
         )
     else:
         # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -2475,6 +2489,7 @@ class PortArgs:
             nccl_port=nccl_port,
             rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
             metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
+            tokenizer_worker_ipc_name=None,
         )
@@ -2787,6 +2787,20 @@ def lru_cache_frozenset(maxsize=128):
     return decorator


+def get_worker_ids_from_req_rids(rids):
+    if isinstance(rids, list):
+        worker_ids = [int(rid.split("_")[0]) for rid in rids]
+    elif isinstance(rids, str):
+        worker_ids = [int(rids.split("_")[0])]
+    else:
+        worker_ids = []
+    return worker_ids
+
+
+def get_origin_rid(rid):
+    return rid.split("_", 1)[1] if "_" in rid else rid
+
+
 def apply_module_patch(target_module, target_function, wrappers):
     original_module, original_function = parse_module_path(
         target_module, target_function, False
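A quick round trip of the two helpers above: the worker-id prefix stamped by the TokenizerManager survives batching and is stripped back off before the id reaches the client-facing meta_info:

from sglang.srt.utils import get_origin_rid, get_worker_ids_from_req_rids

assert get_worker_ids_from_req_rids(["1234_a", "5678_b"]) == [1234, 5678]
assert get_worker_ids_from_req_rids("1234_a") == [1234]
assert get_origin_rid("1234_req-uuid") == "req-uuid"
assert get_origin_rid("no-prefix") == "no-prefix"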
@@ -85,6 +85,7 @@ suites = {
     TestFile("test_mla_int8_deepseek_v3.py", 429),
     TestFile("test_mla_flashinfer.py", 302),
     TestFile("test_mla_fp8.py", 93),
+    TestFile("test_multi_tokenizer.py", 230),
     TestFile("test_no_chunked_prefill.py", 108),
     TestFile("test_no_overlap_scheduler.py", 234),
     TestFile("test_original_logprobs.py", 200),
import unittest
from types import SimpleNamespace

import sglang.srt.managers.io_struct as io_struct
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    auto_config_device,
    get_benchmark_args,
    is_in_ci,
    popen_launch_server,
    run_benchmark,
    write_github_step_summary,
)


class TestMultiTokenizer(CustomTestCase):
    # from test_hicache.py
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--tokenizer-worker-num",
                8,
                "--mem-fraction-static",
                0.7,
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)

    def test_multi_tokenizer_ttft(self):
        # from test_bench_serving.py run_bench_serving
        args = get_benchmark_args(
            base_url=self.base_url,
            dataset_name="random",
            dataset_path="",
            tokenizer=None,
            num_prompts=100,
            random_input_len=4096,
            random_output_len=2048,
            sharegpt_context_len=None,
            request_rate=1,
            disable_stream=False,
            disable_ignore_eos=False,
            seed=0,
            device=auto_config_device(),
            lora_name=None,
        )
        res = run_benchmark(args)
        if is_in_ci():
            write_github_step_summary(
                f"### test_multi_tokenizer_ttft\n"
                f"median_e2e_latency_ms: {res['median_e2e_latency_ms']:.2f} ms\n"
            )
        self.assertLess(res["median_e2e_latency_ms"], 11000)
        self.assertLess(res["median_ttft_ms"], 86)
        self.assertLess(res["median_itl_ms"], 10)


if __name__ == "__main__":
    unittest.main()