"vscode:/vscode.git/clone" did not exist on "11e910b6f88ec10c7549493dab9e2b218e988e97"
Unverified commit 46d44318 authored by Lianmin Zheng, committed by GitHub

Add a new api configure_logging to allow dumping the requests (#2875)

parent 923f5183
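For reference, the new endpoint can be exercised directly over HTTP while a server is running. A minimal sketch (assuming the default port 30000; the JSON keys mirror the fields of the ConfigureLoggingReq dataclass added in this commit):

```python
import requests

# Hedged sketch: toggle request logging and enable request dumping on a
# running SGLang server. All fields are optional; omitted fields are left unchanged.
response = requests.post(
    "http://localhost:30000/configure_logging",
    json={
        "log_requests": True,                                # log each received/finished request
        "dump_requests_folder": "/tmp/sglang_request_dump",  # where the .pkl dumps are written
        "dump_requests_threshold": 1000,                     # flush to disk after this many finished requests
    },
)
assert response.status_code == 200
```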
......@@ -336,7 +336,7 @@ loadTracer.sh python3 -m sglang.launch_server \
--model-path /sgl-workspace/sglang/dummy_grok1 \
--tokenizer-path Xenova/grok-1-tokenizer \
--load-format dummy \
--quant fp8 \
--quantization fp8 \
--tp 8 \
--port 30000 \
--disable-radix-cache 2>&1 | tee "$LOGFILE"
......
......@@ -14,7 +14,7 @@ loadTracer.sh python3 -m sglang.launch_server \
--model-path /sgl-workspace/sglang/dummy_grok1 \
--tokenizer-path Xenova/grok-1-tokenizer \
--load-format dummy \
--quant fp8 \
--quantization fp8 \
--tp 8 \
--port 30000 \
--disable-radix-cache 2>&1 | tee "$LOGFILE"
......@@ -104,7 +104,7 @@ To maximize moe kernel efficiency, need to use below scripts to find out the bes
```bash
#Tuning
#For example, given a run like "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp", which uses batch size 32, input length 1024, and output length 8: from the MoE "--batch" point of view, the prefill batch is 32*1024 = 32768 and the decode batch is 32*1 (only one output token is generated per decode step).
#For example, given a run like "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8", which uses batch size 32, input length 1024, and output length 8: from the MoE "--batch" point of view, the prefill batch is 32*1024 = 32768 and the decode batch is 32*1 (only one output token is generated per decode step).
#So we can tune the decode MoE with the command below:
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
# And use this command to tune the prefill MoE:
......
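The batch values passed to benchmark_moe_rocm.py follow directly from the bench_latency configuration above. A tiny sketch of that arithmetic (the helper below is hypothetical, not part of the repo):

```python
# Hypothetical helper: derive the MoE tuning "--batch" values from a
# bench_latency configuration (batch size and input length).
def moe_tuning_batches(batch_size: int, input_len: int) -> dict:
    return {
        "prefill": batch_size * input_len,  # all prompt tokens are processed together
        "decode": batch_size,               # one output token per sequence per decode step
    }

print(moe_tuning_batches(32, 1024))  # {'prefill': 32768, 'decode': 32}
```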
......@@ -6,7 +6,7 @@
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
# Launch sglang
# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87
# offline
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
......
"""
Copyright 2023-2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Configure the logging settings of a server.
Usage:
python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
"""
import argparse
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:30000")
parser.add_argument(
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
)
parser.add_argument("--dump-requests-threshold", type=int, default=1000)
args = parser.parse_args()
response = requests.post(
args.url + "/configure_logging",
json={
"dump_requests_folder": args.dump_requests_folder,
"dump_requests_threshold": args.dump_requests_threshold,
},
)
assert response.status_code == 200
......@@ -488,6 +488,13 @@ class ProfileReq(Enum):
STOP_PROFILE = 2
@dataclass
class ConfigureLoggingReq:
log_requests: Optional[bool] = None
dump_requests_folder: Optional[str] = None
dump_requests_threshold: Optional[int] = None
@dataclass
class OpenSessionReqInput:
capacity_of_str_len: int
......
......@@ -82,6 +82,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerSta
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
broadcast_pyobj,
configure_logger,
......@@ -92,7 +93,6 @@ from sglang.srt.utils import (
set_random_seed,
suppress_other_loggers,
)
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
......
......@@ -18,10 +18,12 @@ import copy
import dataclasses
import logging
import os
import pickle
import signal
import sys
import time
import uuid
from datetime import datetime
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
import fastapi
......@@ -43,6 +45,7 @@ from sglang.srt.managers.io_struct import (
BatchStrOut,
BatchTokenIDOut,
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
......@@ -109,6 +112,7 @@ class TokenizerManager:
# Parse args
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
self.log_requests = server_args.log_requests
# Init inter-process communication
context = zmq.asyncio.Context(2)
......@@ -167,6 +171,9 @@ class TokenizerManager:
# Store states
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
......@@ -225,7 +232,7 @@ class TokenizerManager:
obj.normalize_batch_and_arguments()
if self.server_args.log_requests:
if self.log_requests:
logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
async with self.model_update_lock.reader_lock:
......@@ -346,7 +353,7 @@ class TokenizerManager:
state.out_list = []
if state.finished:
if self.server_args.log_requests:
if self.log_requests:
msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
logger.info(msg)
del self.rid_to_state[obj.rid]
......@@ -597,6 +604,15 @@ class TokenizerManager:
assert not self.to_create_loop, "close session should not be the first request"
await self.send_to_scheduler.send_pyobj(obj)
def configure_logging(self, obj: ConfigureLoggingReq):
if obj.log_requests is not None:
self.log_requests = obj.log_requests
if obj.dump_requests_folder is not None:
self.dump_requests_folder = obj.dump_requests_folder
if obj.dump_requests_threshold is not None:
self.dump_requests_threshold = obj.dump_requests_threshold
logging.info(f"Config logging: {obj=}")
def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
async def abort_request():
......@@ -708,6 +724,8 @@ class TokenizerManager:
if self.enable_metrics:
self.collect_metrics(state, recv_obj, i)
if self.dump_requests_folder and state.finished:
self.dump_requests(state, out_dict)
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id if recv_obj.success else None
......@@ -850,6 +868,25 @@ class TokenizerManager:
(time.time() - state.created_time) / completion_tokens
)
def dump_requests(self, state: ReqState, out_dict: dict):
self.dump_request_list.append(
(state.obj, out_dict, state.created_time, time.time())
)
if len(self.dump_request_list) >= self.dump_requests_threshold:
to_dump = self.dump_request_list
self.dump_request_list = []
def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
current_time = datetime.now()
filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
pickle.dump(to_dump, f)
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
class SignalHandler:
def __init__(self, tokenizer_manager):
......
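Each file written by dump_requests above holds a pickled list of (request, output, created_time, finished_time) tuples. A minimal sketch for inspecting the dumps offline (assuming the default folder /tmp/sglang_request_dump):

```python
import glob
import os
import pickle

# Hedged sketch: load every dump file and print a one-line summary per request.
dump_folder = "/tmp/sglang_request_dump"  # default folder used by the helper script above
for path in sorted(glob.glob(os.path.join(dump_folder, "*.pkl"))):
    with open(path, "rb") as f:
        records = pickle.load(f)
    for req_obj, out_dict, created_time, finished_time in records:
        latency = finished_time - created_time
        print(f"{os.path.basename(path)}: {type(req_obj).__name__} latency={latency:.3f}s")
```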
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
Memory pool.
......
......@@ -50,6 +50,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
enable_show_time_cost,
get_available_gpu_memory,
......@@ -60,7 +61,6 @@ from sglang.srt.utils import (
monkey_patch_vllm_p2p_access_check,
set_cpu_offload_max_bytes,
)
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
logger = logging.getLogger(__name__)
......
......@@ -31,7 +31,7 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
import torch
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
......@@ -54,6 +54,7 @@ from sglang.srt.managers.data_parallel_controller import (
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
......@@ -161,12 +162,68 @@ async def get_model_info():
@app.get("/get_server_info")
async def get_server_info():
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**dataclasses.asdict(tokenizer_manager.server_args),
**scheduler_info,
"version": __version__,
}
# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
if obj.stream:
async def stream_results() -> AsyncIterator[bytes]:
try:
async for out in tokenizer_manager.generate_request(obj, request):
yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
except ValueError as e:
out = {"error": {"message": str(e)}}
yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
yield b"data: [DONE]\n\n"
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(obj),
)
else:
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
logger.error(f"Error: {e}")
return _create_error_response(e)
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request."""
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return _create_error_response(e)
@app.api_route("/classify", methods=["POST", "PUT"])
@time_func_latency
async def classify_request(obj: EmbeddingReqInput, request: Request):
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return _create_error_response(e)
@app.post("/flush_cache")
async def flush_cache():
"""Flush the radix cache."""
......@@ -178,8 +235,7 @@ async def flush_cache():
)
@app.get("/start_profile")
@app.post("/start_profile")
@app.api_route("/start_profile", methods=["GET", "POST"])
async def start_profile_async():
"""Start profiling."""
tokenizer_manager.start_profile()
......@@ -189,8 +245,7 @@ async def start_profile_async():
)
@app.get("/stop_profile")
@app.post("/stop_profile")
@app.api_route("/stop_profile", methods=["GET", "POST"])
async def stop_profile_async():
"""Stop profiling."""
tokenizer_manager.stop_profile()
......@@ -305,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
return _create_error_response(e)
# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
if obj.stream:
async def stream_results() -> AsyncIterator[bytes]:
try:
async for out in tokenizer_manager.generate_request(obj, request):
yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
except ValueError as e:
out = {"error": {"message": str(e)}}
yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
yield b"data: [DONE]\n\n"
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(obj),
)
else:
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
logger.error(f"Error: {e}")
return _create_error_response(e)
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request."""
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return _create_error_response(e)
@app.api_route("/classify", methods=["POST", "PUT"])
@time_func_latency
async def classify_request(obj: EmbeddingReqInput, request: Request):
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return _create_error_response(e)
@app.api_route("/configure_logging", methods=["GET", "POST"])
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
"""Close the session"""
tokenizer_manager.configure_logging(obj)
return Response(status_code=200)
##### OpenAI-compatible API endpoints #####
......
......@@ -91,7 +91,7 @@ class ServerArgs:
# API related
api_key: Optional[str] = None
file_storage_pth: str = "SGLang_storage"
file_storage_pth: str = "sglang_storage"
enable_cache_report: bool = False
# Data parallelism
......@@ -554,7 +554,7 @@ class ServerArgs:
"--decode-log-interval",
type=int,
default=ServerArgs.decode_log_interval,
help="The log interval of decode batch",
help="The log interval of decode batch.",
)
# API related
......