#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import json
import logging
import os
import shlex
import subprocess

import numpy as np
from prefix_data_generator.synthesizer import Synthesizer

# Module logger: INFO and above, written to stderr with a timestamped format.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(console_handler)


def get_aiperf_cmd_for_trace(
    model,
    tokenizer,
    input_dataset,
    artifact_dir,
    seed,
    url="http://localhost:8888",
):
    """Build the ``aiperf profile`` argument vector for replaying a trace.

    The command targets a streaming OpenAI-style chat endpoint
    (``v1/chat/completions``) and replays ``input_dataset`` as a
    ``mooncake_trace`` custom dataset on its fixed schedule.

    Args:
        model: Model name passed as ``--model``.
        tokenizer: Tokenizer name/path passed as ``--tokenizer``.
        input_dataset: Path (str or path-like) to the mooncake-style JSONL trace.
        artifact_dir: Path (str or path-like) where aiperf writes its artifacts.
        seed: Random seed, passed as ``--random-seed``.
        url: Base server URL (default ``http://localhost:8888``).

    Returns:
        list[str]: Argument list for ``subprocess.run`` (no shell involved).
    """
    return [
        "aiperf",
        "profile",
        "--model",
        str(model),
        "--tokenizer",
        str(tokenizer),
        "--endpoint-type",
        "chat",
        "--endpoint",
        "v1/chat/completions",
        "--streaming",
        "--url",
        str(url),
        "--input-file",
        str(input_dataset),
        "--custom-dataset-type",
        "mooncake_trace",
        "--fixed-schedule-auto-offset",
        "--random-seed",
        str(seed),
        "--artifact-dir",
        str(artifact_dir),
        # Placeholder auth header: the value is not validated by the target
        # server, but the header is sent on every request.
        "-H",
        "Authorization: Bearer NOT USED",
        "-H",
        "Accept: text/event-stream",
    ]


def run_benchmark_with_trace(
    model,
    tokenizer,
    trace_dataset,
    artifact_dir,
    url,
    seed,
):
    """Run aiperf benchmark with a trace dataset.

    aiperf's stdout/stderr are not captured. They stream directly to this
    process's terminal, so any error details are already visible there.

    Args:
        model: Model name to benchmark.
        tokenizer: Tokenizer name/path used by aiperf.
        trace_dataset: Path to the mooncake-style JSONL trace to replay.
        artifact_dir: Directory for aiperf artifacts.
        url: Base URL of the server under test.
        seed: Random seed forwarded to aiperf.

    Raises:
        FileNotFoundError: If the ``aiperf`` executable cannot be found.
        subprocess.CalledProcessError: If aiperf exits with a non-zero status.
    """
    aiperf_cmd = get_aiperf_cmd_for_trace(
        model,
        tokenizer,
        trace_dataset,
        artifact_dir,
        seed,
        url,
    )

    logger.info(f"Running aiperf with trace dataset: {trace_dataset}")
    # shlex.join quotes arguments containing spaces (e.g. the -H header
    # values) so the logged command can be pasted into a shell verbatim.
    logger.info(f"Command: {shlex.join(str(arg) for arg in aiperf_cmd)}")

    try:
        # Run aiperf and let it output directly to terminal
        subprocess.run(aiperf_cmd, check=True)
    except FileNotFoundError:
        logger.error("aiperf executable not found on PATH; is aiperf installed?")
        raise
    except subprocess.CalledProcessError as e:
        # Output was not captured, so e.stderr is always None here; the
        # failure details were already printed by aiperf itself.
        logger.error(f"AIPerf failed with error code: {e.returncode}")
        raise

    logger.info("AIPerf profiling completed successfully")


def _build_arg_parser():
    """Return the command-line parser for this benchmark script."""
    parser = argparse.ArgumentParser(
        description="Benchmark with real or synthesized mooncake-style trace data"
    )

    # Model and server configuration
    parser.add_argument(
        "--model",
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="Model name",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default=None,
        help="Tokenizer name (defaults to model)",
    )
    parser.add_argument(
        "--url",
        type=str,
        default="http://localhost:8000",
        help="Server URL",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="real_data_benchmark_results",
        help="Output directory for results",
    )

    # Trace dataset and synthesis configuration (similar to synthesizer.py)
    parser.add_argument(
        "--input-dataset",
        type=str,
        default="mooncake_trace.jsonl",
        help="Path to the input mooncake-style trace dataset file",
    )
    parser.add_argument(
        "--num-requests",
        type=int,
        default=None,
        help="Number of requests to synthesize (default: use all from input file)",
    )
    parser.add_argument(
        "--speedup-ratio",
        type=float,
        default=1.0,
        help="Factor to speed up request intervals (default: 1.0)",
    )
    parser.add_argument(
        "--prefix-len-multiplier",
        type=float,
        default=1.0,
        help="Multiplier for prefix lengths (default: 1.0)",
    )
    parser.add_argument(
        "--prefix-root-multiplier",
        type=int,
        default=1,
        help="Number of times to replicate the core radix tree (default: 1)",
    )
    parser.add_argument(
        "--prompt-len-multiplier",
        type=float,
        default=1.0,
        help="Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)",
    )
    parser.add_argument(
        "--max-isl",
        type=int,
        default=None,
        help="Maximum input sequence length to include in output (default: None, no filtering)",
    )
    parser.add_argument(
        "--min-isl",
        type=int,
        default=None,
        help="Minimum input sequence length to include in output (default: None, no filtering)",
    )
    parser.add_argument(
        "--min-osl",
        type=int,
        default=None,
        help="Minimum output sequence length - clips values below this threshold (default: None, no clipping)",
    )
    parser.add_argument(
        "--max-osl",
        type=int,
        default=None,
        help="Maximum output sequence length - clips values above this threshold (default: None, no clipping)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=512,
        help="Block size for prefilling and decoding (default: 512)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed for reproducibility (default: 0)",
    )
    parser.add_argument(
        "--use-expected-osl",
        action="store_true",
        help="Pass expected_output_tokens to nvext for router tracking",
    )
    return parser


def _needs_synthesis(args):
    """Return True if any synthesis/filtering option differs from its default."""
    return (
        args.num_requests is not None
        or args.speedup_ratio != 1.0
        or args.prefix_len_multiplier != 1.0
        or args.prefix_root_multiplier != 1
        or args.prompt_len_multiplier != 1.0
        or args.max_isl is not None
        or args.min_isl is not None
        or args.min_osl is not None
        or args.max_osl is not None
    )


def _read_jsonl(path):
    """Load one JSON object per non-blank line of ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line); json.loads("") raises.
        return [json.loads(line) for line in f if line.strip()]


def _count_jsonl_records(path):
    """Count the non-blank lines of ``path``, i.e. the number of trace requests."""
    with open(path, "r", encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as JSON Lines (one object per line)."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")


def _inject_expected_osl(requests):
    """Set ``nvext["expected_output_tokens"]`` from each request's ``output_tokens``.

    Requests without ``output_tokens`` get 0. Any existing ``nvext`` dict is
    updated in place and its other keys are preserved.
    """
    for request in requests:
        request.setdefault("nvext", {})["expected_output_tokens"] = request.get(
            "output_tokens", 0
        )


def _or_label(value, unset_label):
    """Return ``unset_label`` when ``value`` is None, else ``value`` (0 is kept)."""
    return unset_label if value is None else value


def _synthesize_trace(args):
    """Synthesize a trace from ``args.input_dataset`` and return the output path."""
    logger.info("Generating synthetic trace data...")
    logger.info(f"  Base dataset: {args.input_dataset}")
    logger.info(f"  Num requests: {_or_label(args.num_requests, 'all')}")
    logger.info(f"  Speedup ratio: {args.speedup_ratio}")
    logger.info(f"  Prefix len multiplier: {args.prefix_len_multiplier}")
    logger.info(f"  Prefix root multiplier: {args.prefix_root_multiplier}")
    logger.info(f"  Prompt len multiplier: {args.prompt_len_multiplier}")
    logger.info(f"  Max ISL: {_or_label(args.max_isl, 'no limit')} (filtering)")
    logger.info(f"  Min ISL: {_or_label(args.min_isl, 'no limit')} (filtering)")
    logger.info(f"  Min OSL: {_or_label(args.min_osl, 'no clipping')} (clipping)")
    logger.info(f"  Max OSL: {_or_label(args.max_osl, 'no clipping')} (clipping)")
    logger.info(f"  Random seed: {args.seed}")

    # Seed NumPy's global RNG for reproducibility; presumably the Synthesizer
    # draws from it (not visible here).
    np.random.seed(args.seed)

    synthesizer = Synthesizer(
        args.input_dataset,
        block_size=args.block_size,
        speedup_ratio=args.speedup_ratio,
        prefix_len_multiplier=args.prefix_len_multiplier,
        prefix_root_multiplier=args.prefix_root_multiplier,
        prompt_len_multiplier=args.prompt_len_multiplier,
    )

    if args.num_requests is None:
        num_requests = _count_jsonl_records(args.input_dataset)
        logger.info(f"Using all {num_requests} requests from input dataset")
    else:
        num_requests = args.num_requests

    requests = synthesizer.synthesize_requests(
        num_requests,
        max_isl=args.max_isl,
        min_isl=args.min_isl,
        min_osl=args.min_osl,
        max_osl=args.max_osl,
    )
    logger.info(f"Generated {len(requests)} synthetic requests")

    if args.use_expected_osl:
        _inject_expected_osl(requests)
        logger.info("Injected expected_output_tokens into nvext for each request")

    # Save synthetic data to a permanent file in output directory
    trace_dataset_path = os.path.join(args.output_dir, "synthetic_trace.jsonl")
    _write_jsonl(trace_dataset_path, requests)
    logger.info(f"Synthetic trace data saved to: {trace_dataset_path}")
    return trace_dataset_path


def main():
    """Prepare a trace dataset and benchmark it with aiperf.

    The benchmarked trace is one of:
    * the original ``--input-dataset``, when no synthesis options are changed
      and ``--use-expected-osl`` is off;
    * a copy of the original with ``nvext.expected_output_tokens`` injected,
      when only ``--use-expected-osl`` is set;
    * a synthesized trace, when any synthesis/filtering option is changed.
    """
    args = _build_arg_parser().parse_args()

    # Use tokenizer from model if not specified
    if args.tokenizer is None:
        args.tokenizer = args.model

    os.makedirs(args.output_dir, exist_ok=True)

    if _needs_synthesis(args):
        trace_dataset_path = _synthesize_trace(args)
    elif args.use_expected_osl:
        # Only inject expected_output_tokens into nvext, no other synthesis
        logger.info("Injecting expected_output_tokens into original trace dataset...")
        requests = _read_jsonl(args.input_dataset)
        _inject_expected_osl(requests)
        trace_dataset_path = os.path.join(
            args.output_dir, "trace_with_expected_osl.jsonl"
        )
        _write_jsonl(trace_dataset_path, requests)
        logger.info(f"Modified trace data saved to: {trace_dataset_path}")
    else:
        # No synthesis or modification needed, use original dataset
        trace_dataset_path = args.input_dataset
        logger.info(
            f"Using original trace dataset (no synthesis parameters modified): {trace_dataset_path}"
        )

    artifact_dir = os.path.join(args.output_dir, "aiperf_artifacts")
    os.makedirs(artifact_dir, exist_ok=True)

    run_benchmark_with_trace(
        args.model,
        args.tokenizer,
        trace_dataset_path,
        artifact_dir,
        args.url,
        args.seed,
    )

    logger.info(f"Results saved to: {artifact_dir}")


if __name__ == "__main__":
    main()