#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import json
import os
import subprocess

import numpy as np
12
13
14
15
16
17
18
from common import (
    DEFAULT_MOONCAKE_BLOCK_SIZE,
    add_common_args,
    get_common_aiperf_flags,
    resolve_tokenizer,
    setup_logger,
)
19
20
from prefix_data_generator.synthesizer import Synthesizer

21
# Module-level logger, created through the setup_logger helper imported from common.
logger = setup_logger(__name__)
22
23


24
def get_aiperf_cmd_for_trace(
    model,
    tokenizer,
    input_dataset,
    artifact_dir,
    seed,
    block_size,
    url="http://localhost:8888",
):
    """Assemble the ``aiperf profile`` argument list for a mooncake trace file.

    Args:
        model: Value passed to ``--model``.
        tokenizer: Value passed to ``--tokenizer``.
        input_dataset: Trace file passed to ``--input-file`` (formatted to a string).
        artifact_dir: Value passed to ``--artifact-dir``.
        seed: Value passed to ``--random-seed`` (converted with ``str``).
        block_size: Value passed to ``--prompt-input-tokens-block-size``
            (converted with ``str``).
        url: Value passed to ``--url``.

    Returns:
        A list of command-line tokens, ending with whatever
        ``get_common_aiperf_flags()`` contributes.
    """
    # Which model to hit and where.
    target_args = ["--model", model, "--tokenizer", tokenizer, "--url", url]

    # How the trace file is fed to aiperf.
    dataset_args = [
        "--input-file",
        f"{input_dataset}",
        "--custom-dataset-type",
        "mooncake_trace",
        "--fixed-schedule-auto-offset",
        "--prompt-input-tokens-block-size",
        str(block_size),
    ]

    # Seed and output location for this run.
    run_args = ["--random-seed", str(seed), "--artifact-dir", artifact_dir]

    return [
        "aiperf",
        "profile",
        *target_args,
        *dataset_args,
        *run_args,
        *get_common_aiperf_flags(),
    ]
56
57
58
59
60


def run_benchmark_with_trace(
    model,
    tokenizer,
    trace_dataset,
    artifact_dir,
    url,
    seed,
    block_size,
):
    """Run aiperf benchmark with a trace dataset.

    The command is built by ``get_aiperf_cmd_for_trace`` and executed as a
    child process whose stdout/stderr are inherited, so aiperf prints
    directly to the terminal.

    Args:
        model: Model name forwarded to the aiperf command.
        tokenizer: Tokenizer forwarded to the aiperf command.
        trace_dataset: Path of the trace file to replay.
        artifact_dir: Directory forwarded as the aiperf artifact directory.
        url: Endpoint URL forwarded to the aiperf command.
        seed: Random seed forwarded to the aiperf command.
        block_size: Block size forwarded to the aiperf command.

    Raises:
        subprocess.CalledProcessError: If aiperf exits with a non-zero status
            (logged, then re-raised).
    """
    aiperf_cmd = get_aiperf_cmd_for_trace(
        model,
        tokenizer,
        trace_dataset,
        artifact_dir,
        seed,
        block_size,
        url,
    )

    logger.info(f"Running aiperf with trace dataset: {trace_dataset}")
    logger.info(f"Command: {' '.join(aiperf_cmd)}")

    try:
        # Run aiperf and let it output directly to terminal
        subprocess.run(aiperf_cmd, check=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"AIPerf failed with error code: {e.returncode}")
        # Output is not captured above, so e.stderr is None in that case and
        # the error text has already been printed by aiperf itself. Only log
        # stderr when it is actually populated, to avoid a misleading
        # "stderr: None" line.
        if e.stderr:
            logger.error(f"stderr: {e.stderr}")
        raise
    else:
        logger.info("AIPerf profiling completed successfully")


def _build_arg_parser():
    """Create the CLI parser: arguments from ``add_common_args`` plus trace/synthesis options."""
    parser = argparse.ArgumentParser(
        description="Benchmark with real or synthesized mooncake-style trace data"
    )

    # Common arguments
    add_common_args(parser)

    parser.add_argument(
        "--output-dir",
        type=str,
        default="real_data_benchmark_results",
        help="Output directory for results",
    )

    # Trace dataset and synthesis configuration (similar to synthesizer.py)
    parser.add_argument(
        "--input-dataset",
        type=str,
        default="mooncake_trace.jsonl",
        help="Path to the input mooncake-style trace dataset file",
    )
    parser.add_argument(
        "--num-requests",
        type=int,
        default=None,
        help="Number of requests to synthesize (default: use all from input file)",
    )
    parser.add_argument(
        "--speedup-ratio",
        type=float,
        default=1.0,
        help="Factor to speed up request intervals (default: 1.0)",
    )
    parser.add_argument(
        "--prefix-len-multiplier",
        type=float,
        default=1.0,
        help="Multiplier for prefix lengths (default: 1.0)",
    )
    parser.add_argument(
        "--prefix-root-multiplier",
        type=int,
        default=1,
        help="Number of times to replicate the core radix tree (default: 1)",
    )
    parser.add_argument(
        "--prompt-len-multiplier",
        type=float,
        default=1.0,
        help="Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)",
    )
    parser.add_argument(
        "--max-isl",
        type=int,
        default=None,
        help="Maximum input sequence length to include in output (default: None, no filtering)",
    )
    parser.add_argument(
        "--min-isl",
        type=int,
        default=None,
        help="Minimum input sequence length to include in output (default: None, no filtering)",
    )
    parser.add_argument(
        "--min-osl",
        type=int,
        default=None,
        help="Minimum output sequence length - clips values below this threshold (default: None, no clipping)",
    )
    parser.add_argument(
        "--max-osl",
        type=int,
        default=None,
        help="Maximum output sequence length - clips values above this threshold (default: None, no clipping)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=DEFAULT_MOONCAKE_BLOCK_SIZE,
        help=f"Block size for prefilling and decoding (default: {DEFAULT_MOONCAKE_BLOCK_SIZE})",
    )
    return parser


def _needs_synthesis(args):
    """Return True when any synthesis option differs from its parser default."""
    return (
        args.num_requests is not None
        or args.speedup_ratio != 1.0
        or args.prefix_len_multiplier != 1.0
        or args.prefix_root_multiplier != 1
        or args.prompt_len_multiplier != 1.0
        or args.max_isl is not None
        or args.min_isl is not None
        or args.min_osl is not None
        or args.max_osl is not None
    )


def _read_jsonl(path):
    """Load a JSON Lines file into a list of objects, ignoring blank lines."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            # A trailing or stray empty line is not a record; json.loads("")
            # would raise on it.
            if stripped:
                records.append(json.loads(stripped))
    return records


def _write_jsonl(path, requests):
    """Write each request as one JSON document per line to ``path``."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(request) + "\n" for request in requests)


def _inject_expected_osl(requests):
    """Copy each request's ``output_tokens`` (0 if absent) into ``nvext.expected_output_tokens`` in place."""
    for request in requests:
        osl = request.get("output_tokens", 0)
        request.setdefault("nvext", {})["expected_output_tokens"] = osl


def _synthesize_requests(args):
    """Log the synthesis settings, seed NumPy, and return requests produced by ``Synthesizer``."""

    def _or(value, fallback):
        # Use an explicit None check so that a user-supplied 0 is reported
        # as 0, consistent with how _needs_synthesis treats the options.
        return value if value is not None else fallback

    logger.info("Generating synthetic trace data...")
    logger.info(f"  Base dataset: {args.input_dataset}")
    logger.info(f"  Num requests: {_or(args.num_requests, 'all')}")
    logger.info(f"  Speedup ratio: {args.speedup_ratio}")
    logger.info(f"  Prefix len multiplier: {args.prefix_len_multiplier}")
    logger.info(f"  Prefix root multiplier: {args.prefix_root_multiplier}")
    logger.info(f"  Prompt len multiplier: {args.prompt_len_multiplier}")
    logger.info(f"  Max ISL: {_or(args.max_isl, 'no limit')} (filtering)")
    logger.info(f"  Min ISL: {_or(args.min_isl, 'no limit')} (filtering)")
    logger.info(f"  Min OSL: {_or(args.min_osl, 'no clipping')} (clipping)")
    logger.info(f"  Max OSL: {_or(args.max_osl, 'no clipping')} (clipping)")
    logger.info(f"  Random seed: {args.seed}")

    # Set random seed for reproducibility
    np.random.seed(args.seed)

    synthesizer = Synthesizer(
        args.input_dataset,
        block_size=args.block_size,
        speedup_ratio=args.speedup_ratio,
        prefix_len_multiplier=args.prefix_len_multiplier,
        prefix_root_multiplier=args.prefix_root_multiplier,
        prompt_len_multiplier=args.prompt_len_multiplier,
    )

    if args.num_requests is None:
        # Count the records in the original dataset; blank lines are not
        # requests and must not inflate the count.
        with open(args.input_dataset, "r", encoding="utf-8") as f:
            num_requests = sum(1 for line in f if line.strip())
        logger.info(f"Using all {num_requests} requests from input dataset")
    else:
        num_requests = args.num_requests

    requests = synthesizer.synthesize_requests(
        num_requests,
        max_isl=args.max_isl,
        min_isl=args.min_isl,
        min_osl=args.min_osl,
        max_osl=args.max_osl,
    )
    logger.info(f"Generated {len(requests)} synthetic requests")
    return requests


def main():
    """Parse CLI options, prepare the trace file, and run aiperf on it.

    The trace handed to aiperf is one of:
      * the input dataset unchanged, when no synthesis option was modified
        and ``use_expected_osl`` is off;
      * a copy of the input dataset with ``nvext.expected_output_tokens``
        added, when only ``use_expected_osl`` is on;
      * a synthesized trace written to the output directory otherwise.

    ``args.model``, ``args.tokenizer``, ``args.url``, ``args.seed`` and
    ``args.use_expected_osl`` are not declared here; they are expected to be
    registered by ``add_common_args``.
    """
    args = _build_arg_parser().parse_args()
    resolve_tokenizer(args)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    if _needs_synthesis(args):
        requests = _synthesize_requests(args)

        # Save synthetic data to a permanent file in output directory
        trace_dataset_path = os.path.join(args.output_dir, "synthetic_trace.jsonl")

        if args.use_expected_osl:
            _inject_expected_osl(requests)
            logger.info("Injected expected_output_tokens into nvext for each request")

        _write_jsonl(trace_dataset_path, requests)
        logger.info(f"Synthetic trace data saved to: {trace_dataset_path}")
    elif args.use_expected_osl:
        # Only inject expected_output_tokens into nvext, no other synthesis
        logger.info("Injecting expected_output_tokens into original trace dataset...")
        requests = _read_jsonl(args.input_dataset)
        _inject_expected_osl(requests)

        trace_dataset_path = os.path.join(
            args.output_dir, "trace_with_expected_osl.jsonl"
        )
        _write_jsonl(trace_dataset_path, requests)
        logger.info(f"Modified trace data saved to: {trace_dataset_path}")
    else:
        # No synthesis or modification needed, use original dataset
        trace_dataset_path = args.input_dataset
        logger.info(
            f"Using original trace dataset (no synthesis parameters modified): {trace_dataset_path}"
        )

    # Run benchmark with the trace dataset
    artifact_dir = os.path.join(args.output_dir, "aiperf_artifacts")
    os.makedirs(artifact_dir, exist_ok=True)

    run_benchmark_with_trace(
        args.model,
        args.tokenizer,
        trace_dataset_path,
        artifact_dir,
        args.url,
        args.seed,
        args.block_size,
    )

    logger.info(f"Results saved to: {artifact_dir}")


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()