Unverified commit 46d44318, authored by Lianmin Zheng and committed by GitHub

Add a new api configure_logging to allow dumping the requests (#2875)

parent 923f5183
@@ -336,7 +336,7 @@ loadTracer.sh python3 -m sglang.launch_server \
   --model-path /sgl-workspace/sglang/dummy_grok1 \
   --tokenizer-path Xenova/grok-1-tokenizer \
   --load-format dummy \
-  --quant fp8 \
+  --quantization fp8 \
   --tp 8 \
   --port 30000 \
   --disable-radix-cache 2>&1 | tee "$LOGFILE"
...
@@ -14,7 +14,7 @@ loadTracer.sh python3 -m sglang.launch_server \
   --model-path /sgl-workspace/sglang/dummy_grok1 \
   --tokenizer-path Xenova/grok-1-tokenizer \
   --load-format dummy \
-  --quant fp8 \
+  --quantization fp8 \
   --tp 8 \
   --port 30000 \
   --disable-radix-cache 2>&1 | tee "$LOGFILE"
...
@@ -104,7 +104,7 @@ To maximize moe kernel efficiency, need to use below scripts to find out the bes
 ```bash
 #Tuning
-#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
+#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
 #so we can tune decode moe use below command
 python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
 # and use this command to tune prefill moe
...
@@ -6,7 +6,7 @@
 # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
 # Launch sglang
-# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
+# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87
 # offline
 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
...
"""
Copyright 2023-2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Configure the logging settings of a server.
Usage:
python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
"""
import argparse
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:30000")
parser.add_argument(
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
)
parser.add_argument("--dump-requests-threshold", type=int, default=1000)
args = parser.parse_args()
response = requests.post(
args.url + "/configure_logging",
json={
"dump_requests_folder": args.dump_requests_folder,
"dump_requests_threshold": args.dump_requests_threshold,
},
)
assert response.status_code == 200
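
The script above only sets the dump options. The endpoint also accepts the `log_requests` toggle defined on `ConfigureLoggingReq`; a minimal sketch of flipping it at runtime, assuming the server listens on localhost:30000:

```python
import requests

# Turn per-request logging on for a running server without a restart.
# The endpoint is registered for both GET and POST.
resp = requests.post(
    "http://localhost:30000/configure_logging",
    json={"log_requests": True},
)
assert resp.status_code == 200
```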
@@ -488,6 +488,13 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2
 
 
+@dataclass
+class ConfigureLoggingReq:
+    log_requests: Optional[bool] = None
+    dump_requests_folder: Optional[str] = None
+    dump_requests_threshold: Optional[int] = None
+
+
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
...
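All three fields default to None, which the `TokenizerManager.configure_logging` handler further down treats as "leave unchanged", so a client can update any subset of the settings. A small illustrative sketch:

```python
from sglang.srt.managers.io_struct import ConfigureLoggingReq

# Only the dump folder is set; log_requests and dump_requests_threshold
# keep their current server-side values because None fields are skipped.
req = ConfigureLoggingReq(dump_requests_folder="/tmp/sglang_request_dump")
```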
@@ -82,6 +82,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerSta
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -92,7 +93,6 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
...
@@ -18,10 +18,12 @@ import copy
 import dataclasses
 import logging
 import os
+import pickle
 import signal
 import sys
 import time
 import uuid
+from datetime import datetime
 from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import fastapi
@@ -43,6 +45,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
@@ -109,6 +112,7 @@ class TokenizerManager:
         # Parse args
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
+        self.log_requests = server_args.log_requests
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -167,6 +171,9 @@ class TokenizerManager:
         # Store states
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.dump_requests_folder = ""  # By default do not dump
+        self.dump_requests_threshold = 1000
+        self.dump_request_list: List[Tuple] = []
 
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -225,7 +232,7 @@ class TokenizerManager:
         obj.normalize_batch_and_arguments()
 
-        if self.server_args.log_requests:
+        if self.log_requests:
             logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
 
         async with self.model_update_lock.reader_lock:
@@ -346,7 +353,7 @@ class TokenizerManager:
             state.out_list = []
             if state.finished:
-                if self.server_args.log_requests:
+                if self.log_requests:
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
                     logger.info(msg)
                 del self.rid_to_state[obj.rid]
@@ -597,6 +604,15 @@ class TokenizerManager:
         assert not self.to_create_loop, "close session should not be the first request"
         await self.send_to_scheduler.send_pyobj(obj)
 
+    def configure_logging(self, obj: ConfigureLoggingReq):
+        if obj.log_requests is not None:
+            self.log_requests = obj.log_requests
+        if obj.dump_requests_folder is not None:
+            self.dump_requests_folder = obj.dump_requests_folder
+        if obj.dump_requests_threshold is not None:
+            self.dump_requests_threshold = obj.dump_requests_threshold
+        logging.info(f"Config logging: {obj=}")
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
@@ -708,6 +724,8 @@ class TokenizerManager:
                 if self.enable_metrics:
                     self.collect_metrics(state, recv_obj, i)
+                if self.dump_requests_folder and state.finished:
+                    self.dump_requests(state, out_dict)
         elif isinstance(recv_obj, OpenSessionReqOutput):
             self.session_futures[recv_obj.session_id].set_result(
                 recv_obj.session_id if recv_obj.success else None
@@ -850,6 +868,25 @@ class TokenizerManager:
                     (time.time() - state.created_time) / completion_tokens
                 )
 
+    def dump_requests(self, state: ReqState, out_dict: dict):
+        self.dump_request_list.append(
+            (state.obj, out_dict, state.created_time, time.time())
+        )
+        if len(self.dump_request_list) >= self.dump_requests_threshold:
+            to_dump = self.dump_request_list
+            self.dump_request_list = []
+
+            def background_task():
+                os.makedirs(self.dump_requests_folder, exist_ok=True)
+                current_time = datetime.now()
+                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
+                with open(
+                    os.path.join(self.dump_requests_folder, filename), "wb"
+                ) as f:
+                    pickle.dump(to_dump, f)
+
+            # Schedule the task to run in the background without awaiting it
+            asyncio.create_task(asyncio.to_thread(background_task))
+
 
 class SignalHandler:
     def __init__(self, tokenizer_manager):
...
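Each dump file written by `dump_requests` is a pickled list of `(obj, out_dict, created_time, finished_time)` tuples. A sketch for inspecting one offline; the file name below is hypothetical, real names follow the `%Y-%m-%d_%H-%M-%S.pkl` pattern above:

```python
import pickle

# Hypothetical dump file under the configured dump_requests_folder.
with open("/tmp/sglang_request_dump/2025-01-17_12-00-00.pkl", "rb") as f:
    records = pickle.load(f)

for obj, out_dict, created_time, finished_time in records:
    # created_time and finished_time are time.time() timestamps in seconds.
    latency = finished_time - created_time
    print(f"latency={latency:.3f}s rid={getattr(obj, 'rid', None)}")
```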
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
 Memory pool.
...
@@ -50,6 +50,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
@@ -60,7 +61,6 @@ from sglang.srt.utils import (
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 logger = logging.getLogger(__name__)
...
@@ -31,7 +31,7 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
 import torch
 
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -54,6 +54,7 @@ from sglang.srt.managers.data_parallel_controller import (
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
@@ -161,12 +162,68 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
     return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **dataclasses.asdict(tokenizer_manager.server_args),
         **scheduler_info,
         "version": __version__,
     }
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
+@time_func_latency
+async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
+    if obj.stream:
+
+        async def stream_results() -> AsyncIterator[bytes]:
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            logger.error(f"Error: {e}")
+            return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+@time_func_latency
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
@@ -178,8 +235,7 @@ async def flush_cache():
     )
 
-@app.get("/start_profile")
-@app.post("/start_profile")
+@app.api_route("/start_profile", methods=["GET", "POST"])
 async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
@@ -189,8 +245,7 @@ async def start_profile_async():
     )
 
-@app.get("/stop_profile")
-@app.post("/stop_profile")
+@app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
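
With the `@app.api_route(..., methods=["GET", "POST"])` registration, both profiling endpoints can be triggered from either verb; for example, assuming the server runs at localhost:30000:

```python
import requests

# Start and later stop the server-side profiler.
requests.post("http://localhost:30000/start_profile")
requests.post("http://localhost:30000/stop_profile")
```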
@@ -305,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         return _create_error_response(e)
 
-# fastapi implicitly converts json in the request to obj (dataclass)
-@app.api_route("/generate", methods=["POST", "PUT"])
-@time_func_latency
-async def generate_request(obj: GenerateReqInput, request: Request):
-    """Handle a generate request."""
-    if obj.stream:
-
-        async def stream_results() -> AsyncIterator[bytes]:
-            try:
-                async for out in tokenizer_manager.generate_request(obj, request):
-                    yield b"data: " + orjson.dumps(
-                        out, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-            except ValueError as e:
-                out = {"error": {"message": str(e)}}
-                yield b"data: " + orjson.dumps(
-                    out, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            yield b"data: [DONE]\n\n"
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(obj),
-        )
-    else:
-        try:
-            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-            return ret
-        except ValueError as e:
-            logger.error(f"Error: {e}")
-            return _create_error_response(e)
-
-
-@app.api_route("/encode", methods=["POST", "PUT"])
-@time_func_latency
-async def encode_request(obj: EmbeddingReqInput, request: Request):
-    """Handle an embedding request."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
-
-
-@app.api_route("/classify", methods=["POST", "PUT"])
-@time_func_latency
-async def classify_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Configure the request logging options."""
+    tokenizer_manager.configure_logging(obj)
+    return Response(status_code=200)
 
 
 ##### OpenAI-compatible API endpoints #####
...
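For completeness, a sketch of a client consuming the SSE stream produced by the relocated `/generate` endpoint above. The prompt and the `text`/`stream` request fields are illustrative assumptions about `GenerateReqInput`, and the URL assumes a local server:

```python
import json

import requests

# Stream tokens from /generate, which yields "data: {...}\n\n" chunks
# and a final "data: [DONE]" sentinel, as in the handler above.
with requests.post(
    "http://localhost:30000/generate",
    json={"text": "The capital of France is", "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        print(json.loads(payload))
```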
@@ -91,7 +91,7 @@ class ServerArgs:
     # API related
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGLang_storage"
+    file_storage_pth: str = "sglang_storage"
     enable_cache_report: bool = False
 
     # Data parallelism
@@ -554,7 +554,7 @@ class ServerArgs:
         "--decode-log-interval",
         type=int,
         default=ServerArgs.decode_log_interval,
-        help="The log interval of decode batch",
+        help="The log interval of decode batch.",
     )
 
     # API related
...