replay_request_dump.py
"""
Usage:
# replay from a folder
python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/

# replay from a single file
python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl
"""

import argparse
import glob
import json
import pickle
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from datetime import datetime

import requests

from sglang.bench_serving import set_ulimit
from sglang.utils import get_exception_traceback


def read_records(files):
    """Load request records from a list of pickle files.

    Each file holds either a dict with a "requests" key or a bare list of
    records.
    """
    records = []
    for f in files:
        with open(f, "rb") as fin:
            tmp = pickle.load(fin)
        if isinstance(tmp, dict) and "requests" in tmp:
            records.extend(tmp["requests"])
        else:
            records.extend(tmp)

    return records


def run_one_request_internal(record):
    (req, output, replay_init_time, start_time, end_time, idx) = record
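    # Pace the replay: sleep until this request's recorded offset from the
    # start of the trace has elapsed, scaled down by --speed (>1 is faster).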
    time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))

    if "completion_tokens" in output.get("meta_info", {}):
        recorded_completion_tokens = output["meta_info"]["completion_tokens"]
    else:
        recorded_completion_tokens = ""

    json_data = asdict(req)
    stream = json_data["stream"]

    if args.ignore_eos:
        # Ignore EOS and cap generation at the recorded completion length so
        # the replayed request reproduces the original output length.
        json_data["sampling_params"]["ignore_eos"] = True
        if recorded_completion_tokens:
            json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens

    response = requests.post(
        f"http://{args.host}:{args.port}/generate",
        json=json_data,
        stream=stream,
    )

    if stream:
        # Consume the SSE stream; keep the last "data:" payload seen before
        # "[DONE]", whose meta_info holds the final token counts used below.
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                ret = json.loads(chunk[5:].strip("\n"))
    else:
        ret = response.json()

    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    print(
        f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, "
        f"{completion_tokens=}, {recorded_completion_tokens=}"
    )


def run_one_request(record):
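    # Swallow per-request failures so one bad record does not abort the replay.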
    # global success_ct, error_ct

    try:
        run_one_request_internal(record)
        # success_ct += 1
    except Exception:
        # error_ct += 1
        traceback = get_exception_traceback()
        print(f"Hit an exception: {traceback}")


def main(records):
    if len(records) == 0:
        return

    # Use the receive time of the earliest record as time zero, and rebase
    # every record's start_time to an offset from that origin.
    base_time = records[0][-2]
    base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S")
    print(f"{base_time_str=}")
    replay_init_time = time.time()

    for i in range(len(records)):
        req, output, start_time, end_time = records[i]
        start_time -= base_time
        records[i] = (req, output, replay_init_time, start_time, end_time, i)

    # Issue requests from a pool of --parallel workers; exiting the context
    # manager waits for all in-flight requests to finish.
    with ThreadPoolExecutor(args.parallel) as executor:
        executor.map(run_one_request, records)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument(
        "--input-folder", type=str, default=None, help="Folder containing pickle files"
    )
    parser.add_argument(
        "--input-file", type=str, default=None, help="Single pickle file to process"
    )
    parser.add_argument("--file-number", type=int, default=1)
    parser.add_argument("--req-number", type=int, default=1000000)
    parser.add_argument("--req-start", type=int, default=0)
    parser.add_argument("--parallel", type=int, default=512)
    parser.add_argument("--idx", type=int, default=None)
    parser.add_argument("--ignore-eos", action="store_true")
124
    parser.add_argument("--speed", type=float, default=1)
    args = parser.parse_args()

    set_ulimit()

    files = []
    if args.input_file:
        files = [args.input_file]
        if args.file_number > 1:
            print("Warning: --file-number is ignored when --input-file is provided.")
    elif args.input_folder:
        files = glob.glob(f"{args.input_folder}/*.pkl")
        files = files[: args.file_number]
    else:
        print("Error: Either --input-folder or --input-file must be provided.")
        exit(1)
    print(f"{files=}")

    records = read_records(files)
    # Sort by the receive time, before filtering
    records.sort(key=lambda x: x[-2])
    records = records[args.req_start : args.req_start + args.req_number]
    if args.idx is not None:
        records = [records[args.idx]]
        print(f"testing {args.idx=}")
        print(f"{records[0]}")
    print(f"{len(records)=}")
    main(records)