Unverified Commit 8e70064c authored by Lianmin Zheng, committed by GitHub

Clean up server launch code and multi tokenizer (#12132)

parent d98b81e2
......@@ -806,7 +806,7 @@ def _get_and_verify_dtype(
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
config_dtype = getattr(config, "dtype", None)
if isinstance(config_dtype, str):
config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
if config_dtype is None:
......
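The NOTE in this hunk deserves a concrete illustration: `getattr`'s default is only used when the attribute is missing, not when it exists with value None, which is why the code defaults to None and resolves explicitly. A minimal sketch (the `Config` class here is hypothetical):

```python
import torch

_STR_DTYPE_TO_TORCH_DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16}

class Config:
    dtype = None  # the attribute exists, but its value is None

config = Config()
# getattr returns None here, not torch.float32, because the attribute exists:
assert getattr(config, "dtype", torch.float32) is None

# The pattern in the diff: default to None, then resolve explicitly.
config_dtype = getattr(config, "dtype", None)
if isinstance(config_dtype, str):
    config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
if config_dtype is None:
    config_dtype = torch.float32  # explicit fallback
```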
......@@ -101,7 +101,7 @@ class Engine(EngineBase):
Note:
1. The HTTP server, Engine, and TokenizerManager all run in the main process.
2. Inter-process communication (IPC) is handled via the ZMQ library, with each process using a different port.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
def __init__(self, **kwargs):
......@@ -109,6 +109,8 @@ class Engine(EngineBase):
The arguments of this function are the same as `sglang/srt/server_args.py::ServerArgs`.
Please refer to `ServerArgs` for the documentation.
"""
# Parse server_args
if "server_args" in kwargs:
# Directly load server_args
server_args = kwargs["server_args"]
......@@ -118,29 +120,28 @@ class Engine(EngineBase):
# Do not print logs by default
kwargs["log_level"] = "error"
server_args = ServerArgs(**kwargs)
self.server_args = server_args
logger.info(f"{server_args=}")
# Shutdown the subprocesses automatically when the program exits
atexit.register(self.shutdown)
# Allocate ports for inter-process communications
self.port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}")
# Launch subprocesses
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args,
port_args=self.port_args,
tokenizer_manager, template_manager, scheduler_info, port_args = (
_launch_subprocesses(server_args=server_args)
)
self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
self.template_manager = template_manager
self.scheduler_info = scheduler_info
self.port_args = port_args
# Initialize ZMQ sockets
context = zmq.Context(2)
self.send_to_rpc = get_zmq_socket(
context, zmq.DEALER, self.port_args.rpc_ipc_name, True
)
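`get_zmq_socket` is sglang's internal helper; the following is a plausible sketch of what such a helper does (bind for the server side, connect for the client side), not the actual implementation:

```python
import zmq

def get_zmq_socket_sketch(
    context: zmq.Context, socket_type: int, endpoint: str, bind: bool
) -> zmq.Socket:
    # Create the socket, then either bind (owning side) or connect (peer side).
    socket = context.socket(socket_type)
    # HWM 0 means "no limit" in ZMQ, so messages are not dropped under load.
    socket.set_hwm(0)
    if bind:
        socket.bind(endpoint)
    else:
        socket.connect(endpoint)
    return socket
```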
# Enable tracing
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
......@@ -672,15 +673,17 @@ def _set_envs_and_config(server_args: ServerArgs):
os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
if not server_args.enable_symm_mem:
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
# flashinfer uses this environment variable for various kernels from MoE to quant kernels
if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
# flashinfer uses this environment variable for various kernels from MoE to quant kernels
os.environ["TRTLLM_ENABLE_PDL"] = "1"
if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
# Default to warning level, to avoid too many logs
os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
# Need to set log to console, otherwise the log level won't take effect
os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
......@@ -840,7 +843,7 @@ def _launch_subprocesses(
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
# When using `Engine` as a Python API, we don't want to block here.
return None, None, None
return None, None, None, port_args
launch_dummy_health_check_server(
server_args.host, server_args.port, server_args.enable_metrics
......@@ -851,7 +854,7 @@ def _launch_subprocesses(
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
return None, None, None
return None, None, None, port_args
# Launch detokenizer process
detoken_proc = mp.Process(
......@@ -897,4 +900,4 @@ def _launch_subprocesses(
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, template_manager, scheduler_info
return tokenizer_manager, template_manager, scheduler_info, port_args
......@@ -20,7 +20,7 @@ This file implements HTTP APIs for the inference engine via fastapi.
import asyncio
import dataclasses
import logging
import multiprocessing as multiprocessing
import multiprocessing
import os
import tempfile
import threading
......@@ -165,6 +165,7 @@ async def init_multi_tokenizer() -> ServerArgs:
server_args.api_key is None
), "API key is not supported in multi-tokenizer mode"
# Create a new ipc name for the current process
port_args.tokenizer_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
)
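`tempfile.NamedTemporaryFile(delete=False)` yields a filesystem path that is unique per call, which gives every worker process a collision-free `ipc://` endpoint. A minimal sketch of the pattern:

```python
import tempfile

def new_ipc_endpoint() -> str:
    # Each call creates a fresh, uniquely named file; sglang later binds
    # a ZMQ unix domain socket at that path.
    return f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"

a, b = new_ipc_endpoint(), new_ipc_endpoint()
assert a != b  # two workers never share an endpoint
```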
......@@ -184,6 +185,7 @@ async def init_multi_tokenizer() -> ServerArgs:
)
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
......@@ -192,36 +194,35 @@ async def init_multi_tokenizer() -> ServerArgs:
)
)
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = f"MultiTokenizer-{tokenizer_manager.worker_id}"
trace_set_thread_info(thread_label)
return server_args
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
if getattr(fast_api_app, "is_single_tokenizer_mode", False):
server_args = fast_api_app.server_args
warmup_thread_args = fast_api_app.warmup_thread_args
thread_label = "Tokenizer"
else:
# Initialize multi-tokenizer support for worker processes
fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
# only metrics middleware is supported in multi-tokenizer mode
worker_pid = os.getpid()
if fast_api_app.server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
logger.info(f"Worker {worker_pid} added prometheus middleware")
fast_api_app.warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
fast_api_app.server_args,
None, # pipe_finish_writer not needed in worker
None, # launch_callback not needed in worker
),
server_args = await init_multi_tokenizer()
warmup_thread_args = (
server_args,
None,
None,
)
thread_label = f"MultiTokenizer-{_global_state.tokenizer_manager.worker_id}"
# Add prometheus middleware
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Init tracing
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
trace_set_thread_info(thread_label)
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
......@@ -249,8 +250,7 @@ async def lifespan(fast_api_app: FastAPI):
_global_state.tokenizer_manager
)
server_args: ServerArgs = fast_api_app.server_args
# Launch tool server
tool_server = None
if server_args.tool_server == "demo":
from sglang.srt.entrypoints.openai.tool_server import DemoToolServer
......@@ -274,12 +274,11 @@ async def lifespan(fast_api_app: FastAPI):
enable_force_include_usage=True,
tool_server=tool_server,
)
except Exception as e:
import traceback
traceback.print_exc()
logger.warning(f"Can not initialize OpenAIServingResponses, error: {e}")
except Exception:
traceback = get_exception_traceback()
logger.warning(f"Can not initialize OpenAIServingResponses, error: {traceback}")
# Execute custom warmups
if server_args.warmups is not None:
await execute_warmups(
server_args.disaggregation_mode,
......@@ -288,18 +287,18 @@ async def lifespan(fast_api_app: FastAPI):
)
logger.info("Warmup ended")
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
# Execute the general warmup
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=warmup_thread_args,
)
warmup_thread.start()
# Start the HTTP server
try:
yield
finally:
if server_args.tokenizer_worker_num > 1:
pid = os.getpid()
logger.info(f"uvicorn worker {pid} ending...")
warmup_thread.join()
logger.info(f"uvicorn worker {pid} ended.")
warmup_thread.join()
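The lifespan above starts the warmup thread before yielding control to the server and joins it during shutdown. A self-contained sketch of that pattern, with a stand-in for `_wait_and_warmup` (names hypothetical):

```python
import threading
from contextlib import asynccontextmanager

from fastapi import FastAPI

def _fake_warmup():  # stand-in for _wait_and_warmup
    print("warmup done")

@asynccontextmanager
async def lifespan(app: FastAPI):
    warmup_thread = threading.Thread(target=_fake_warmup)
    warmup_thread.start()
    try:
        yield  # the server handles requests here
    finally:
        warmup_thread.join()  # don't exit while warmup is still running

app = FastAPI(lifespan=lifespan)
```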
# Fast API
......@@ -1328,27 +1327,12 @@ def launch_server(
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
Note:
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
1. The HTTP server, Engine, and TokenizerManager all run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
if server_args.tokenizer_worker_num > 1:
port_args = PortArgs.init_new(server_args)
port_args.tokenizer_worker_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
)
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args, port_args=port_args
)
else:
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args,
)
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Tokenizer"
trace_set_thread_info(thread_label)
tokenizer_manager, template_manager, scheduler_info, port_args = (
_launch_subprocesses(server_args=server_args)
)
set_global_state(
_GlobalState(
......@@ -1358,40 +1342,45 @@ def launch_server(
)
)
if server_args.tokenizer_worker_num > 1:
multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
port_args,
# Pass additional arguments to the lifespan function.
# They will be used for additional initialization setups.
if server_args.tokenizer_worker_num == 1:
# If it is single tokenizer mode, we can pass the arguments as attributes of the app object.
app.is_single_tokenizer_mode = True
app.server_args = server_args
app.warmup_thread_args = (
server_args,
scheduler_info,
pipe_finish_writer,
launch_callback,
)
else:
# Add API key authorization
# This is only supported in single tokenizer mode.
if server_args.api_key:
add_api_key_middleware(app, server_args.api_key)
# Add prometheus middleware
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Send a warmup request - we will create the thread and launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
pipe_finish_writer,
launch_callback,
),
else:
# If it is multi-tokenizer mode, we need to write the arguments to shared memory
# for other worker processes to read.
app.is_single_tokenizer_mode = False
multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
port_args, server_args, scheduler_info
)
app.warmup_thread = warmup_thread
try:
# Update logging configs
set_uvicorn_logging_configs()
app.server_args = server_args
# Listen for HTTP requests
if server_args.tokenizer_worker_num > 1:
if server_args.tokenizer_worker_num == 1:
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level_http or server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
)
else:
from uvicorn.config import LOGGING_CONFIG
LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
......@@ -1399,7 +1388,6 @@ def launch_server(
"level": "INFO",
"propagate": False,
}
monkey_patch_uvicorn_multiprocessing()
uvicorn.run(
......@@ -1411,22 +1399,10 @@ def launch_server(
loop="uvloop",
workers=server_args.tokenizer_worker_num,
)
else:
app.is_single_tokenizer_mode = True
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level_http or server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
)
finally:
if server_args.tokenizer_worker_num > 1:
multi_tokenizer_args_shm.unlink()
_global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
else:
warmup_thread.join()
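Multi-tokenizer mode leans on uvicorn's multi-worker support: each worker imports the app and runs its own lifespan, where `init_multi_tokenizer` allocates a per-process IPC endpoint. A minimal sketch (the module path is hypothetical; `workers > 1` requires passing the app as an import string):

```python
import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "my_server:app",   # import string, required when workers > 1
        host="0.0.0.0",
        port=30000,
        workers=4,         # corresponds to server_args.tokenizer_worker_num
        loop="uvloop",
        timeout_keep_alive=5,
    )
```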
def _execute_server_warmup(
......
......@@ -152,7 +152,6 @@ def initialize_moe_config(server_args: ServerArgs):
def get_moe_a2a_backend() -> MoeA2ABackend:
global MOE_A2A_BACKEND
if MOE_A2A_BACKEND is None:
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
MOE_A2A_BACKEND = MoeA2ABackend.NONE
return MOE_A2A_BACKEND
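`get_moe_a2a_backend` is a lazily initialized module-level singleton: the first call fills in a default when `initialize_moe_config` was never run. A generic sketch of the pattern:

```python
from typing import Optional

_BACKEND: Optional[str] = None

def initialize_backend(name: str) -> None:
    global _BACKEND
    _BACKEND = name

def get_backend() -> str:
    global _BACKEND
    if _BACKEND is None:
        _BACKEND = "none"  # fall back to a safe default instead of raising
    return _BACKEND
```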
......
......@@ -13,7 +13,12 @@ from __future__ import annotations
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mixin class and utils for multi-http-worker mode"""
"""
Mixin classes and utilities for multi-http-worker mode.
This file uses multiple processes to handle requests and tokenization, reducing Python and HTTP server overhead.
"""
import asyncio
import logging
import multiprocessing as multiprocessing
......@@ -566,3 +571,14 @@ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
logger.warning(
"uvicorn.supervisors.multiprocess not found, skipping monkey patch"
)
class SenderWrapper:
def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket):
self.port_args = port_args
self.send_to_scheduler = send_to_scheduler
def send_pyobj(self, obj):
if isinstance(obj, BaseReq):
obj.http_worker_ipc = self.port_args.tokenizer_ipc_name
self.send_to_scheduler.send_pyobj(obj)
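`SenderWrapper` is a duck-typed stand-in for a plain ZMQ socket: callers keep calling `send_pyobj`, and every `BaseReq` gets stamped with this worker's tokenizer IPC name so the scheduler knows where to route the response. A runnable sketch with mock stand-ins (all names hypothetical):

```python
from dataclasses import dataclass

@dataclass
class FakeReq:       # stand-in for BaseReq
    http_worker_ipc: str = ""

class FakeSocket:    # stand-in for a zmq PUSH socket
    def send_pyobj(self, obj):
        print("routed via", obj.http_worker_ipc)

class SenderWrapperSketch:
    def __init__(self, tokenizer_ipc_name: str, send_to_scheduler):
        self.tokenizer_ipc_name = tokenizer_ipc_name
        self.send_to_scheduler = send_to_scheduler

    def send_pyobj(self, obj):
        # Stamp each request with this worker's IPC name for response routing.
        if isinstance(obj, FakeReq):
            obj.http_worker_ipc = self.tokenizer_ipc_name
        self.send_to_scheduler.send_pyobj(obj)

SenderWrapperSketch("ipc:///tmp/worker0", FakeSocket()).send_pyobj(FakeReq())
```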
......@@ -123,7 +123,7 @@ class SchedulerMetricsMixin:
token_usage_msg = f"token usage: {token_usage:.2f}, "
f = (
f"Prefill batch [{self.forward_ct + 1}], "
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
......@@ -246,7 +246,7 @@ class SchedulerMetricsMixin:
gap_latency / self.server_args.decode_log_interval
)
msg = f"Decode batch [{self.forward_ct}], #running-req: {num_running_reqs}, {token_usage_msg}"
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
......
......@@ -46,7 +46,6 @@ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchT
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BaseReq,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchStrOutput,
......@@ -171,7 +170,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id
self.max_req_input_len = None # Will be set later in engine.py
speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
......@@ -180,9 +178,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if speculative_algorithm.is_none()
else server_args.speculative_num_draft_tokens
)
# Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
self.multi_item_delimiter_text = None
# Initialize tokenizer and processor
if self.model_config.is_multimodal:
import_processors("sglang.srt.multimodal.processors")
try:
......@@ -237,6 +234,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
revision=server_args.revision,
)
self._initialize_multi_item_delimiter_text()
# Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
if (
server_args.enable_dynamic_batch_tokenizer
......@@ -255,24 +253,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
if self.server_args.tokenizer_worker_num > 1:
if self.server_args.tokenizer_worker_num == 1:
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
else:
from sglang.srt.managers.multi_tokenizer_mixin import SenderWrapper
# Use tokenizer_worker_ipc_name in multi-tokenizer mode
send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
)
class SenderWrapper:
def send_pyobj(self, obj):
if isinstance(obj, BaseReq):
obj.http_worker_ipc = port_args.tokenizer_ipc_name
send_to_scheduler.send_pyobj(obj)
# Make sure that each request carries the tokenizer_ipc_name for response routing
self.send_to_scheduler = SenderWrapper()
else:
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler)
# Request states
self._chosen_loop = None
......@@ -320,6 +314,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# LoRA updates and inference to overlap.
self.lora_update_lock = asyncio.Lock()
# Disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
......@@ -389,9 +384,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
obj.normalize_batch_and_arguments()
if self.server_args.tokenizer_worker_num > 1:
from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker
assert isinstance(self, TokenizerWorker)
self._attach_multi_http_worker_info(obj)
if self.enable_trace:
......
......@@ -3745,6 +3745,13 @@ class PortArgs:
else:
nccl_port = server_args.nccl_port
if server_args.tokenizer_worker_num > 1:
tokenizer_worker_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
)
else:
tokenizer_worker_ipc_name = None
if not server_args.enable_dp_attention:
# Normal case, use IPC within a single node
return PortArgs(
......@@ -3754,7 +3761,7 @@ class PortArgs:
nccl_port=nccl_port,
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
tokenizer_worker_ipc_name=None,
tokenizer_worker_ipc_name=tokenizer_worker_ipc_name,
)
else:
# DP attention. Use TCP + port to handle both single-node and multi-node.
......@@ -3788,7 +3795,7 @@ class PortArgs:
nccl_port=nccl_port,
rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
tokenizer_worker_ipc_name=None,
tokenizer_worker_ipc_name=tokenizer_worker_ipc_name,
)
......
......@@ -56,7 +56,6 @@ from json import JSONDecodeError
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
......@@ -94,9 +93,6 @@ from typing_extensions import Literal
from sglang.srt.environ import envs
from sglang.srt.metrics.func_timer import enable_func_timer
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
show_time_cost = False
......@@ -1106,9 +1102,9 @@ def add_api_key_middleware(app, api_key: str):
async def authentication(request, call_next):
if request.method == "OPTIONS":
return await call_next(request)
if request.url.path.startswith("/health"):
return await call_next(request)
if request.url.path.startswith("/metrics"):
if request.url.path.startswith("/health") or request.url.path.startswith(
"/metrics"
):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + api_key:
return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
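A minimal usage sketch of the middleware above with FastAPI (the import path is an assumption and may differ across sglang versions; the key is hypothetical):

```python
from fastapi import FastAPI

# add_api_key_middleware is defined in the hunk above; this import path
# is assumed and may differ.
from sglang.srt.utils import add_api_key_middleware

app = FastAPI()
add_api_key_middleware(app, "secret-key")

# Clients must now send:  Authorization: Bearer secret-key
# OPTIONS requests and paths starting with /health or /metrics bypass the check.
```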
......