Unverified commit 3b141e15, authored by Lianmin Zheng and committed by GitHub
Browse files

Dump requests (#2862)

parent 6249e4a1
......@@ -18,10 +18,12 @@ import copy
import dataclasses
import logging
import os
import pickle
import signal
import sys
import time
import uuid
from datetime import datetime
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
import fastapi
......@@ -105,6 +107,7 @@ class TokenizerManager:
# Parse args
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
self.dump_requsts_folder = server_args.dump_requests_folder
# Init inter-process communication
context = zmq.asyncio.Context(2)
......@@ -163,6 +166,7 @@ class TokenizerManager:
# Store states
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}
self.dump_request_list: List[Tuple] = []
# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
......@@ -680,6 +684,9 @@ class TokenizerManager:
if self.enable_metrics:
self.collect_metrics(state, recv_obj, i)
if self.dump_requsts_folder and state.finished:
self.dump_requests(state, out_dict)
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id if recv_obj.success else None
......@@ -818,6 +825,27 @@ class TokenizerManager:
(time.time() - state.created_time) / completion_tokens
)
def dump_requests(self, state: ReqState, out_dict: dict):
    """Buffer a finished request and flush the buffer to disk once it is large.

    Each call appends a ``(state.obj, out_dict, state.created_time, now)``
    tuple to ``self.dump_request_list``. Once the buffer holds more entries
    than ``SGLANG_DUMP_REQUESTS_THRESHOLD`` (environment variable, default
    ``"1000"``), the buffered entries are pickled to
    ``<dump folder>/<YYYY-mm-dd_HH-MM-SS>.pkl`` in a worker thread and the
    buffer is replaced by an empty list.

    The flush is scheduled with ``asyncio.create_task``, so this must be
    called while an event loop is running.

    Args:
        state: Request state; only ``state.obj`` and ``state.created_time``
            are read here.
        out_dict: Output stored next to the request in the dump.
    """
    self.dump_request_list.append(
        (state.obj, out_dict, state.created_time, time.time())
    )

    threshold = int(os.environ.get("SGLANG_DUMP_REQUESTS_THRESHOLD", "1000"))
    if len(self.dump_request_list) <= threshold:
        return

    # Hand the full buffer to the writer and start a fresh one, so requests
    # that finish while the file is being written are not part of this dump.
    to_dump = self.dump_request_list
    self.dump_request_list = []
    # The attribute is spelled "requsts" where it is assigned in __init__.
    folder = self.dump_requsts_folder

    def background_task():
        # Runs in a worker thread. Log failures here; otherwise an I/O or
        # pickling error would only surface as an "exception was never
        # retrieved" warning when the task is garbage collected.
        try:
            os.makedirs(folder, exist_ok=True)
            filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
            with open(os.path.join(folder, filename), "wb") as f:
                pickle.dump(to_dump, f)
        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to dump %d requests to %r", len(to_dump), folder
            )

    # Schedule the task to run in the background without awaiting it.
    # The event loop only keeps weak references to tasks, so hold a strong
    # reference until the task finishes to keep it from being collected.
    pending = getattr(self, "_dump_tasks", None)
    if pending is None:
        pending = self._dump_tasks = set()
    task = asyncio.create_task(asyncio.to_thread(background_task))
    pending.add(task)
    task.add_done_callback(pending.discard)
class SignalHandler:
def __init__(self, tokenizer_manager):
......
......@@ -23,7 +23,6 @@ from typing import List, Optional
import torch
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
......@@ -89,6 +88,7 @@ class ServerArgs:
show_time_cost: bool = False
enable_metrics: bool = False
decode_log_interval: int = 40
dump_requests_folder: str = ""
# API related
api_key: Optional[str] = None
......@@ -554,7 +554,13 @@ class ServerArgs:
"--decode-log-interval",
type=int,
default=ServerArgs.decode_log_interval,
help="The log interval of decode batch",
help="The log interval of decode batch.",
)
# The default must come from the matching ServerArgs field (an empty string,
# i.e. dumping disabled). argparse does not run `type=` on non-string
# defaults, so an integer default would reach the code that treats this
# value as a folder path.
parser.add_argument(
    "--dump-requests-folder",
    type=str,
    default=ServerArgs.dump_requests_folder,
    help="Dump raw requests to a folder for replay.",
)
# API related
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment