Unverified Commit 8e70064c authored by Lianmin Zheng, committed by GitHub

Clean up server launch code and multi tokenizer (#12132)

parent d98b81e2
@@ -806,7 +806,7 @@ def _get_and_verify_dtype(
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "torch_dtype", None)
+    config_dtype = getattr(config, "dtype", None)
     if isinstance(config_dtype, str):
         config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
     if config_dtype is None:
...
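The hunk above switches the lookup from `config.torch_dtype` to `config.dtype` (presumably tracking the transformers rename of `torch_dtype` to `dtype`) while keeping the None-safe fallback. A minimal sketch of the pattern, with a stand-in mapping (the real `_STR_DTYPE_TO_TORCH_DTYPE` lives elsewhere in this file and covers more dtypes):

```python
import torch

# Stand-in for the module's mapping; illustrative only.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

def resolve_config_dtype(config):
    # getattr(config, "dtype", torch.float32) would be wrong: the
    # attribute can exist and still be None, so default to None and
    # let the caller pick a fallback.
    config_dtype = getattr(config, "dtype", None)
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    return config_dtype
```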
@@ -101,7 +101,7 @@ class Engine(EngineBase):
     Note:
     1. The HTTP server, Engine, and TokenizerManager all run in the main process.
-    2. Inter-process communication (IPC) is handled via the ZMQ library, with each process using a different port.
+    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """

     def __init__(self, **kwargs):
@@ -109,6 +109,8 @@ class Engine(EngineBase):
         The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`.
         Please refer to `ServerArgs` for the documentation.
         """
+        # Parse server_args
         if "server_args" in kwargs:
             # Directly load server_args
             server_args = kwargs["server_args"]
@@ -118,29 +120,28 @@ class Engine(EngineBase):
             # Do not print logs by default
             kwargs["log_level"] = "error"
             server_args = ServerArgs(**kwargs)
+        self.server_args = server_args
+        logger.info(f"{server_args=}")

         # Shutdown the subprocesses automatically when the program exits
         atexit.register(self.shutdown)

-        # Allocate ports for inter-process communications
-        self.port_args = PortArgs.init_new(server_args)
-        logger.info(f"{server_args=}")
-
         # Launch subprocesses
-        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args,
-            port_args=self.port_args,
+        tokenizer_manager, template_manager, scheduler_info, port_args = (
+            _launch_subprocesses(server_args=server_args)
         )
-        self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.template_manager = template_manager
         self.scheduler_info = scheduler_info
+        self.port_args = port_args

+        # Initialize ZMQ sockets
         context = zmq.Context(2)
         self.send_to_rpc = get_zmq_socket(
             context, zmq.DEALER, self.port_args.rpc_ipc_name, True
         )

+        # Enable tracing
         if server_args.enable_trace:
             process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
             if server_args.disaggregation_mode == "null":
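For readers unfamiliar with the IPC setup the docstring describes: each side gets a ZMQ socket bound or connected to an `ipc://` endpoint, and the Engine keeps a DEALER socket for RPC. A minimal, self-contained sketch of DEALER-over-IPC messaging (this is not SGLang's `get_zmq_socket`; POSIX only, endpoint path made up for the example):

```python
import os
import tempfile

import zmq

context = zmq.Context()

# A fresh filesystem path turned into a ZMQ ipc:// endpoint.
endpoint = f"ipc://{os.path.join(tempfile.mkdtemp(), 'rpc')}"

server = context.socket(zmq.DEALER)  # e.g. the Engine side
server.bind(endpoint)

client = context.socket(zmq.DEALER)  # e.g. the scheduler side
client.connect(endpoint)

client.send(b"ping")
print(server.recv())  # b"ping"
```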
@@ -672,15 +673,17 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
     if not server_args.enable_symm_mem:
         os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
+    # flashinfer uses this environment variable for various kernels from MoE to quant kernels
     if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
-        # flashinfer uses this environment variable for various kernels from MoE to quant kernels
         os.environ["TRTLLM_ENABLE_PDL"] = "1"
     if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
         # Default to warning level, to avoid too many logs
         os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
     if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
         # Need to set log to console, otherwise the log level won't take effect
         os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
@@ -840,7 +843,7 @@ def _launch_subprocesses(
         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
             # When using `Engine` as a Python API, we don't want to block here.
-            return None, None, None
+            return None, None, None, port_args

         launch_dummy_health_check_server(
             server_args.host, server_args.port, server_args.enable_metrics
@@ -851,7 +854,7 @@ def _launch_subprocesses(
             logger.error(
                 f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
             )
-            return None, None, None
+            return None, None, None, port_args

     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -897,4 +900,4 @@ def _launch_subprocesses(
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-    return tokenizer_manager, template_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info, port_args
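Every return path of `_launch_subprocesses` now carries `port_args`, so callers unpack four values unconditionally, even on the early-return paths. A toy analogue (names and values are illustrative, not SGLang's):

```python
from typing import Optional, Tuple

def launch(ready: bool) -> Tuple[Optional[str], Optional[str], Optional[dict], str]:
    # port_args is created up front, so even early-return paths can
    # hand it back to the caller.
    port_args = "ipc:///tmp/example"
    if not ready:
        return None, None, None, port_args
    return "tokenizer_manager", "template_manager", {"max_req_input_len": 4096}, port_args

tm, tpl, info, port_args = launch(ready=False)
assert tm is None and port_args == "ipc:///tmp/example"
```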
@@ -20,7 +20,7 @@ This file implements HTTP APIs for the inference engine via fastapi.
 import asyncio
 import dataclasses
 import logging
-import multiprocessing as multiprocessing
+import multiprocessing
 import os
 import tempfile
 import threading
@@ -165,6 +165,7 @@ async def init_multi_tokenizer() -> ServerArgs:
         server_args.api_key is None
     ), "API key is not supported in multi-tokenizer mode"
+    # Create a new ipc name for the current process
     port_args.tokenizer_ipc_name = (
         f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
     )
@@ -184,6 +185,7 @@ async def init_multi_tokenizer() -> ServerArgs:
     )
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]

     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
@@ -192,36 +194,35 @@ async def init_multi_tokenizer() -> ServerArgs:
         )
     )
-    if server_args.enable_trace:
-        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
-        if server_args.disaggregation_mode == "null":
-            thread_label = f"MultiTokenizer-{tokenizer_manager.worker_id}"
-            trace_set_thread_info(thread_label)
     return server_args
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
+    if getattr(fast_api_app, "is_single_tokenizer_mode", False):
+        server_args = fast_api_app.server_args
+        warmup_thread_args = fast_api_app.warmup_thread_args
+        thread_label = "Tokenizer"
+    else:
         # Initialize multi-tokenizer support for worker processes
-        fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
-
-        # only metrics middleware is supported in multi-tokenizer mode
-        worker_pid = os.getpid()
-        if fast_api_app.server_args.enable_metrics:
-            add_prometheus_middleware(app)
-            enable_func_timer()
-            logger.info(f"Worker {worker_pid} added prometheus middleware")
-
-        fast_api_app.warmup_thread = threading.Thread(
-            target=_wait_and_warmup,
-            args=(
-                fast_api_app.server_args,
-                None,  # pipe_finish_writer not needed in worker
-                None,  # launch_callback not needed in worker
-            ),
-        )
+        server_args = await init_multi_tokenizer()
+        warmup_thread_args = (
+            server_args,
+            None,
+            None,
+        )
+        thread_label = f"MultiTokenizer-{_global_state.tokenizer_manager.worker_id}"

+    # Add prometheus middleware
+    if server_args.enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()

+    # Init tracing
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            trace_set_thread_info(thread_label)

     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
@@ -249,8 +250,7 @@ async def lifespan(fast_api_app: FastAPI):
         _global_state.tokenizer_manager
     )

-    server_args: ServerArgs = fast_api_app.server_args
-
+    # Launch tool server
     tool_server = None
     if server_args.tool_server == "demo":
         from sglang.srt.entrypoints.openai.tool_server import DemoToolServer
@@ -274,12 +274,11 @@ async def lifespan(fast_api_app: FastAPI):
                 enable_force_include_usage=True,
                 tool_server=tool_server,
             )
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
-            logger.warning(f"Can not initialize OpenAIServingResponses, error: {e}")
+        except Exception:
+            traceback = get_exception_traceback()
+            logger.warning(f"Can not initialize OpenAIServingResponses, error: {traceback}")

+    # Execute custom warmups
     if server_args.warmups is not None:
         await execute_warmups(
             server_args.disaggregation_mode,
@@ -288,18 +287,18 @@ async def lifespan(fast_api_app: FastAPI):
         )
         logger.info("Warmup ended")

-    warmup_thread = getattr(fast_api_app, "warmup_thread", None)
-    if warmup_thread is not None:
-        warmup_thread.start()
+    # Execute the general warmup
+    warmup_thread = threading.Thread(
+        target=_wait_and_warmup,
+        args=warmup_thread_args,
+    )
+    warmup_thread.start()

+    # Start the HTTP server
     try:
         yield
     finally:
-        if server_args.tokenizer_worker_num > 1:
-            pid = os.getpid()
-            logger.info(f"uvicorn worker {pid} ending...")
-            warmup_thread.join()
-            logger.info(f"uvicorn worker {pid} ended.")
+        warmup_thread.join()

 # Fast API
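The refactored lifespan always creates and joins the warmup thread itself, instead of conditionally picking up a thread built in `launch_server`. A runnable skeleton of that shape, with a stub standing in for `_wait_and_warmup`:

```python
import threading
from contextlib import asynccontextmanager

from fastapi import FastAPI

def _warmup_stub(label: str) -> None:
    # Hypothetical stand-in for _wait_and_warmup.
    print(f"warmup for {label} finished")

@asynccontextmanager
async def lifespan(app: FastAPI):
    single = getattr(app, "is_single_tokenizer_mode", False)
    label = "Tokenizer" if single else "MultiTokenizer"

    # Start the warmup in the background before serving requests...
    warmup_thread = threading.Thread(target=_warmup_stub, args=(label,))
    warmup_thread.start()
    try:
        yield
    finally:
        # ...and make sure it has finished before the worker exits.
        warmup_thread.join()

app = FastAPI(lifespan=lifespan)
```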
@@ -1328,27 +1327,12 @@ def launch_server(
     3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

     Note:
-    1. The HTTP server, Engine, and TokenizerManager both run in the main process.
+    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    if server_args.tokenizer_worker_num > 1:
-        port_args = PortArgs.init_new(server_args)
-        port_args.tokenizer_worker_ipc_name = (
-            f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
-        )
-        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args, port_args=port_args
-        )
-    else:
-        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args,
-        )
-
-    if server_args.enable_trace:
-        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
-        if server_args.disaggregation_mode == "null":
-            thread_label = "Tokenizer"
-            trace_set_thread_info(thread_label)
+    tokenizer_manager, template_manager, scheduler_info, port_args = (
+        _launch_subprocesses(server_args=server_args)
+    )

     set_global_state(
         _GlobalState(
@@ -1358,40 +1342,45 @@ def launch_server(
         )
     )
+    # Pass additional arguments to the lifespan function.
+    # They will be used for additional initialization setups.
-    if server_args.tokenizer_worker_num > 1:
-        multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
-            port_args,
-            server_args,
-            scheduler_info,
-        )
-    else:
+    if server_args.tokenizer_worker_num == 1:
+        # If it is single tokenizer mode, we can pass the arguments by attributes of the app object.
+        app.is_single_tokenizer_mode = True
+        app.server_args = server_args
+        app.warmup_thread_args = (
+            server_args,
+            pipe_finish_writer,
+            launch_callback,
+        )

         # Add api key authorization
+        # This is only supported in single tokenizer mode.
         if server_args.api_key:
             add_api_key_middleware(app, server_args.api_key)
-
-        # Add prometheus middleware
-        if server_args.enable_metrics:
-            add_prometheus_middleware(app)
-            enable_func_timer()
-
-        # Send a warmup request - we will create the thread launch it
-        # in the lifespan after all other warmups have fired.
-        warmup_thread = threading.Thread(
-            target=_wait_and_warmup,
-            args=(
-                server_args,
-                pipe_finish_writer,
-                launch_callback,
-            ),
-        )
-        app.warmup_thread = warmup_thread
+    else:
+        # If it is multi-tokenizer mode, we need to write the arguments to shared memory
+        # for other worker processes to read.
+        app.is_single_tokenizer_mode = False
+        multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
+            port_args, server_args, scheduler_info
+        )
     try:
         # Update logging configs
         set_uvicorn_logging_configs()
-        app.server_args = server_args

         # Listen for HTTP requests
-        if server_args.tokenizer_worker_num > 1:
+        if server_args.tokenizer_worker_num == 1:
+            uvicorn.run(
+                app,
+                host=server_args.host,
+                port=server_args.port,
+                log_level=server_args.log_level_http or server_args.log_level,
+                timeout_keep_alive=5,
+                loop="uvloop",
+            )
+        else:
             from uvicorn.config import LOGGING_CONFIG

             LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
@@ -1399,7 +1388,6 @@ def launch_server(
                 "level": "INFO",
                 "propagate": False,
             }
-
             monkey_patch_uvicorn_multiprocessing()

             uvicorn.run(
@@ -1411,22 +1399,10 @@ def launch_server(
                 loop="uvloop",
                 workers=server_args.tokenizer_worker_num,
             )
-        else:
-            app.is_single_tokenizer_mode = True
-            uvicorn.run(
-                app,
-                host=server_args.host,
-                port=server_args.port,
-                log_level=server_args.log_level_http or server_args.log_level,
-                timeout_keep_alive=5,
-                loop="uvloop",
-            )
     finally:
         if server_args.tokenizer_worker_num > 1:
             multi_tokenizer_args_shm.unlink()
             _global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
-        else:
-            warmup_thread.join()

 def _execute_server_warmup(
...
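On the multi-worker path, uvicorn itself forks the HTTP workers. Independent of SGLang's monkey patch, stock uvicorn only honors `workers > 1` when the application is given as an import string rather than an object; a generic sketch (the module path is hypothetical):

```python
import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "myproject.server:app",  # import string, required for workers > 1
        host="0.0.0.0",
        port=30000,
        workers=4,
        loop="uvloop",
    )
```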
@@ -152,7 +152,6 @@ def initialize_moe_config(server_args: ServerArgs):
 def get_moe_a2a_backend() -> MoeA2ABackend:
     global MOE_A2A_BACKEND
     if MOE_A2A_BACKEND is None:
-        logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
         MOE_A2A_BACKEND = MoeA2ABackend.NONE
     return MOE_A2A_BACKEND
...
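Only the warning is dropped here, presumably because it fired on a benign path; the lazy default itself stays. The shape of the accessor, reduced to a self-contained sketch:

```python
from enum import Enum
from typing import Optional

class MoeA2ABackend(Enum):  # stand-in with just the default member
    NONE = "none"

MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None

def get_moe_a2a_backend() -> MoeA2ABackend:
    # Fall back to NONE silently when initialize_moe_config was never called.
    global MOE_A2A_BACKEND
    if MOE_A2A_BACKEND is None:
        MOE_A2A_BACKEND = MoeA2ABackend.NONE
    return MOE_A2A_BACKEND
```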
@@ -13,7 +13,12 @@ from __future__ import annotations
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Mixin class and utils for multi-http-worker mode"""
+"""
+Mixin classes and utils for multi-http-worker mode
+
+This file uses multiple processes to handle requests and tokenization, reducing the overhead of python and http server.
+"""

 import asyncio
 import logging
 import multiprocessing as multiprocessing
@@ -566,3 +571,14 @@ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
         logger.warning(
             "uvicorn.supervisors.multiprocess not found, skipping monkey patch"
         )
+
+class SenderWrapper:
+    def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket):
+        self.port_args = port_args
+        self.send_to_scheduler = send_to_scheduler
+
+    def send_pyobj(self, obj):
+        if isinstance(obj, BaseReq):
+            obj.http_worker_ipc = self.port_args.tokenizer_ipc_name
+        self.send_to_scheduler.send_pyobj(obj)
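To see what `SenderWrapper` buys: every `BaseReq` passing through it is stamped with the worker's own tokenizer IPC name, so responses can be routed back to the right HTTP worker. A toy round-trip with stand-in request/port types (the real `BaseReq` and `PortArgs` live in `sglang.srt.managers.io_struct` and `sglang.srt.server_args`; the class body is copied from the hunk above):

```python
import os
import tempfile
from dataclasses import dataclass
from typing import Optional

import zmq

@dataclass
class BaseReq:  # stand-in for the real io_struct class
    http_worker_ipc: Optional[str] = None

@dataclass
class PortArgs:  # stand-in, only the field SenderWrapper touches
    tokenizer_ipc_name: str

class SenderWrapper:
    def __init__(self, port_args, send_to_scheduler):
        self.port_args = port_args
        self.send_to_scheduler = send_to_scheduler

    def send_pyobj(self, obj):
        # Stamp the request with this worker's IPC name for response routing.
        if isinstance(obj, BaseReq):
            obj.http_worker_ipc = self.port_args.tokenizer_ipc_name
        self.send_to_scheduler.send_pyobj(obj)

ctx = zmq.Context()
endpoint = f"ipc://{os.path.join(tempfile.mkdtemp(), 'sched')}"
pull = ctx.socket(zmq.PULL)
pull.bind(endpoint)
push = ctx.socket(zmq.PUSH)
push.connect(endpoint)

sender = SenderWrapper(PortArgs(tokenizer_ipc_name="ipc:///tmp/tok-0"), push)
sender.send_pyobj(BaseReq())
print(pull.recv_pyobj().http_worker_ipc)  # ipc:///tmp/tok-0
```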
@@ -123,7 +123,7 @@ class SchedulerMetricsMixin:
         token_usage_msg = f"token usage: {token_usage:.2f}, "
         f = (
-            f"Prefill batch [{self.forward_ct + 1}], "
+            f"Prefill batch. "
             f"#new-seq: {len(can_run_list)}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
@@ -246,7 +246,7 @@ class SchedulerMetricsMixin:
             gap_latency / self.server_args.decode_log_interval
         )
-        msg = f"Decode batch [{self.forward_ct}], #running-req: {num_running_reqs}, {token_usage_msg}"
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
...
@@ -46,7 +46,6 @@ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchT
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BaseReq,
     BatchEmbeddingOutput,
     BatchMultimodalOutput,
     BatchStrOutput,
@@ -171,7 +170,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
         self.max_req_input_len = None  # Will be set later in engine.py
-
         speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -180,9 +178,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if speculative_algorithm.is_none()
             else server_args.speculative_num_draft_tokens
         )
-        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
-        self.multi_item_delimiter_text = None

+        # Initialize tokenizer and processor
         if self.model_config.is_multimodal:
             import_processors("sglang.srt.multimodal.processors")
             try:
@@ -237,6 +234,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 revision=server_args.revision,
             )
             self._initialize_multi_item_delimiter_text()
+
         # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
         if (
             server_args.enable_dynamic_batch_tokenizer
@@ -255,24 +253,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        if self.server_args.tokenizer_worker_num > 1:
+        if self.server_args.tokenizer_worker_num == 1:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )
+        else:
+            from sglang.srt.managers.multi_tokenizer_mixin import SenderWrapper
+
             # Use tokenizer_worker_ipc_name in multi-tokenizer mode
             send_to_scheduler = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
             )
-
-            class SenderWrapper:
-                def send_pyobj(self, obj):
-                    if isinstance(obj, BaseReq):
-                        obj.http_worker_ipc = port_args.tokenizer_ipc_name
-                    send_to_scheduler.send_pyobj(obj)
-
             # Make sure that each request carries the tokenizer_ipc_name for response routing
-            self.send_to_scheduler = SenderWrapper()
-        else:
-            self.send_to_scheduler = get_zmq_socket(
-                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-            )
+            self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler)

         # Request states
         self._chosen_loop = None
@@ -320,6 +314,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

+        # Disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -389,9 +384,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         obj.normalize_batch_and_arguments()

         if self.server_args.tokenizer_worker_num > 1:
-            from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker
-
-            assert isinstance(self, TokenizerWorker)
             self._attach_multi_http_worker_info(obj)

         if self.enable_trace:
...
@@ -3745,6 +3745,13 @@ class PortArgs:
         else:
             nccl_port = server_args.nccl_port

+        if server_args.tokenizer_worker_num > 1:
+            tokenizer_worker_ipc_name = (
+                f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
+            )
+        else:
+            tokenizer_worker_ipc_name = None

         if not server_args.enable_dp_attention:
             # Normal case, use IPC within a single node
             return PortArgs(
@@ -3754,7 +3761,7 @@ class PortArgs:
                 nccl_port=nccl_port,
                 rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
-                tokenizer_worker_ipc_name=None,
+                tokenizer_worker_ipc_name=tokenizer_worker_ipc_name,
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -3788,7 +3795,7 @@ class PortArgs:
                 nccl_port=nccl_port,
                 rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
                 metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
-                tokenizer_worker_ipc_name=None,
+                tokenizer_worker_ipc_name=tokenizer_worker_ipc_name,
             )
...
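Allocating `tokenizer_worker_ipc_name` inside `PortArgs.init_new` replaces the ad-hoc allocation that `launch_server` used to do. The tempfile trick in isolation (`alloc_ipc_name` is a hypothetical helper name):

```python
import tempfile

def alloc_ipc_name() -> str:
    # A unique filesystem path, usable as a ZMQ ipc:// endpoint.
    return f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"

tokenizer_worker_num = 4  # example value
tokenizer_worker_ipc_name = alloc_ipc_name() if tokenizer_worker_num > 1 else None
print(tokenizer_worker_ipc_name)
```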
@@ -56,7 +56,6 @@ from json import JSONDecodeError
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -94,9 +93,6 @@ from typing_extensions import Literal
 from sglang.srt.environ import envs
 from sglang.srt.metrics.func_timer import enable_func_timer

-if TYPE_CHECKING:
-    pass

 logger = logging.getLogger(__name__)

 show_time_cost = False
@@ -1106,9 +1102,9 @@ def add_api_key_middleware(app, api_key: str):
     async def authentication(request, call_next):
         if request.method == "OPTIONS":
             return await call_next(request)
-        if request.url.path.startswith("/health"):
-            return await call_next(request)
-        if request.url.path.startswith("/metrics"):
+        if request.url.path.startswith("/health") or request.url.path.startswith(
+            "/metrics"
+        ):
             return await call_next(request)
         if request.headers.get("Authorization") != "Bearer " + api_key:
             return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
...
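The merged check keeps the same allow-list semantics as the two separate ifs. A tiny, standalone equivalent for sanity-checking the path logic:

```python
def is_public_path(path: str) -> bool:
    # Health and metrics endpoints bypass API-key auth.
    return path.startswith("/health") or path.startswith("/metrics")

assert is_public_path("/health")
assert is_public_path("/health_generate")
assert is_public_path("/metrics")
assert not is_public_path("/v1/chat/completions")
```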