Unverified commit 93d69061, authored by Lianmin Zheng, committed by GitHub

Simplify the process launch code in server.py (#2923)

parent e00e5385
@@ -44,7 +44,6 @@ import uvloop
 from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
-from uvicorn.config import LOGGING_CONFIG
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -97,6 +96,7 @@ from sglang.srt.utils import (
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
     set_ulimit,
+    set_uvicorn_logging_configs,
 )
 from sglang.utils import get_exception_traceback
 from sglang.version import __version__
@@ -474,13 +474,13 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )
-    memory_saver_adapter = TorchMemorySaverAdapter.create(
-        enable=server_args.enable_memory_saver
-    )
+    scheduler_procs = []
     if server_args.dp_size == 1:
         # Launch tensor parallel scheduler processes
-        scheduler_procs = []
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
         tp_rank_range = range(
@@ -498,12 +498,6 @@
             proc.start()
             scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
-
-        if server_args.node_rank >= 1:
-            # For other nodes, they do not need to run tokenizer or detokenizer,
-            # so they can just wait here.
-            for proc in scheduler_procs:
-                proc.join()
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
@@ -512,8 +506,27 @@
             target=run_data_parallel_controller_process,
             args=(server_args, port_args, writer),
         )
-        with memory_saver_adapter.configure_subprocess():
-            proc.start()
+        proc.start()
+        scheduler_procs.append(proc)
+
+    if server_args.node_rank >= 1:
+        # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
+        # so they can just wait here.
+        for reader in scheduler_pipe_readers:
+            data = reader.recv()
+            assert data["status"] == "ready"
+
+        if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
+            # When using `Engine` as a Python API, we don't want to block here.
+            return
+
+        for proc in scheduler_procs:
+            proc.join()
+            logger.error(
+                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+            )
+        return

     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -597,14 +610,7 @@ def launch_server(
     try:
         # Update logging configs
-        LOGGING_CONFIG["formatters"]["default"][
-            "fmt"
-        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
-        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
-        LOGGING_CONFIG["formatters"]["access"][
-            "fmt"
-        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
-        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+        set_uvicorn_logging_configs()

         # Listen for HTTP requests
         uvicorn.run(
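Taken together, the launch_engine hunks above replace the two separate node_rank >= 1 paths with a single block that runs after either branch has filled scheduler_procs and scheduler_pipe_readers: wait for each child to report readiness over its pipe, return early when the caller is the Python Engine API, otherwise block on join and log if a child exits. A minimal, self-contained sketch of that pattern follows; run_worker, launch, and the process count are illustrative placeholders, not the real sglang entry points.

import multiprocessing as mp
import os


def run_worker(writer):
    # Stand-in for a scheduler subprocess: initialize, then report readiness.
    writer.send({"status": "ready"})


def launch(num_procs, node_rank):
    procs, readers = [], []
    for _ in range(num_procs):
        reader, writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_worker, args=(writer,))
        proc.start()
        procs.append(proc)
        readers.append(reader)

    if node_rank >= 1:
        # Non-zero-rank nodes run no tokenizer/detokenizer: wait for the ready
        # handshake, then either hand control back (Engine API) or block on join.
        for reader in readers:
            assert reader.recv()["status"] == "ready"
        if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
            return procs
        for proc in procs:
            proc.join()
        return procs

    # The rank-0 node would go on to launch the detokenizer and tokenizer here.
    return procs


if __name__ == "__main__":
    launch(num_procs=2, node_rank=1)

The early return is gated on SGLANG_BLOCK_NONZERO_RANK_CHILDREN == "0" so that, per the comment in the diff, an Engine caller keeps control of the child processes instead of blocking inside launch_engine.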
@@ -59,6 +59,7 @@ from triton.runtime.cache import (
     default_dump_dir,
     default_override_dir,
 )
+from uvicorn.config import LOGGING_CONFIG

 logger = logging.getLogger(__name__)
@@ -1404,3 +1405,14 @@ def nullable_str(val: str):
     if not val or val == "None":
         return None
     return val
+
+
+def set_uvicorn_logging_configs():
+    LOGGING_CONFIG["formatters"]["default"][
+        "fmt"
+    ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+    LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+    LOGGING_CONFIG["formatters"]["access"][
+        "fmt"
+    ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+    LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"