#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Benchmark an endpoint with aiperf using a mooncake-style trace dataset.

The input trace is used unchanged, or has ``nvext.expected_output_tokens``
injected (``--use-expected-osl``), or is re-synthesized with ``Synthesizer``
when any synthesis parameter differs from its default. The resulting trace is
then replayed with ``aiperf profile``.
"""

import argparse
import json
import os
import shlex
import subprocess

import numpy as np
from common import (
    DEFAULT_MOONCAKE_BLOCK_SIZE,
    add_common_args,
    get_common_aiperf_flags,
    resolve_tokenizer,
    setup_logger,
)
from prefix_data_generator.synthesizer import Synthesizer

logger = setup_logger(__name__)


def get_aiperf_cmd_for_trace(
    model,
    tokenizer,
    input_dataset,
    artifact_dir,
    seed,
    block_size,
    url="http://localhost:8888",
):
    """Build the argv list for ``aiperf profile`` replaying a mooncake trace.

    Args:
        model: Model name passed to ``--model``.
        tokenizer: Tokenizer name/path passed to ``--tokenizer``.
        input_dataset: Path to the trace file passed to ``--input-file``.
        artifact_dir: Directory passed to ``--artifact-dir``.
        seed: Value passed to ``--random-seed``.
        block_size: Value passed to ``--prompt-input-tokens-block-size``.
        url: Endpoint URL passed to ``--url``.

    Returns:
        The command as a list of strings, followed by the flags returned by
        ``get_common_aiperf_flags()``.
    """
    cmd = [
        "aiperf",
        "profile",
        "--model",
        model,
        "--tokenizer",
        tokenizer,
        "--url",
        url,
        "--input-file",
        f"{input_dataset}",
        "--custom-dataset-type",
        "mooncake_trace",
        "--fixed-schedule-auto-offset",
        "--prompt-input-tokens-block-size",
        str(block_size),
        "--random-seed",
        str(seed),
        "--artifact-dir",
        artifact_dir,
    ]
    cmd.extend(get_common_aiperf_flags())
    return cmd


def run_benchmark_with_trace(
    model,
    tokenizer,
    trace_dataset,
    artifact_dir,
    url,
    seed,
    block_size,
):
    """Run aiperf benchmark with a trace dataset.

    aiperf's stdout/stderr are not captured; they stream straight to the
    terminal.

    Raises:
        FileNotFoundError: If the ``aiperf`` executable cannot be found.
        subprocess.CalledProcessError: If aiperf exits with a non-zero code.
    """
    aiperf_cmd = get_aiperf_cmd_for_trace(
        model,
        tokenizer,
        trace_dataset,
        artifact_dir,
        seed,
        block_size,
        url,
    )

    logger.info(f"Running aiperf with trace dataset: {trace_dataset}")
    # shlex.join quotes each argument so the logged command can be pasted
    # into a shell verbatim.
    logger.info(f"Command: {shlex.join(aiperf_cmd)}")

    try:
        subprocess.run(aiperf_cmd, check=True)
    except FileNotFoundError:
        logger.error("aiperf executable not found; is it installed and on PATH?")
        raise
    except subprocess.CalledProcessError as e:
        logger.error(f"AIPerf failed with error code: {e.returncode}")
        # Output was not captured (e.stderr is None), so point at the
        # terminal output instead of logging an empty value.
        logger.error("See the aiperf output above for details")
        raise
    logger.info("AIPerf profiling completed successfully")


def _build_parser():
    """Create the CLI argument parser (common args plus trace/synthesis options)."""
    parser = argparse.ArgumentParser(
        description="Benchmark with real or synthesized mooncake-style trace data"
    )

    # Common arguments
    add_common_args(parser)

    parser.add_argument(
        "--output-dir",
        type=str,
        default="real_data_benchmark_results",
        help="Output directory for results",
    )

    # Trace dataset and synthesis configuration (similar to synthesizer.py)
    parser.add_argument(
        "--input-dataset",
        type=str,
        default="mooncake_trace.jsonl",
        help="Path to the input mooncake-style trace dataset file",
    )
    parser.add_argument(
        "--num-requests",
        type=int,
        default=None,
        help="Number of requests to synthesize (default: use all from input file)",
    )
    parser.add_argument(
        "--speedup-ratio",
        type=float,
        default=1.0,
        help="Factor to speed up request intervals (default: 1.0)",
    )
    parser.add_argument(
        "--prefix-len-multiplier",
        type=float,
        default=1.0,
        help="Multiplier for prefix lengths (default: 1.0)",
    )
    parser.add_argument(
        "--prefix-root-multiplier",
        type=int,
        default=1,
        help="Number of times to replicate the core radix tree (default: 1)",
    )
    parser.add_argument(
        "--prompt-len-multiplier",
        type=float,
        default=1.0,
        help="Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)",
    )
    parser.add_argument(
        "--max-isl",
        type=int,
        default=None,
        help="Maximum input sequence length to include in output (default: None, no filtering)",
    )
    parser.add_argument(
        "--min-isl",
        type=int,
        default=None,
        help="Minimum input sequence length to include in output (default: None, no filtering)",
    )
    parser.add_argument(
        "--min-osl",
        type=int,
        default=None,
        help="Minimum output sequence length - clips values below this threshold (default: None, no clipping)",
    )
    parser.add_argument(
        "--max-osl",
        type=int,
        default=None,
        help="Maximum output sequence length - clips values above this threshold (default: None, no clipping)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=DEFAULT_MOONCAKE_BLOCK_SIZE,
        help=f"Block size for prefilling and decoding (default: {DEFAULT_MOONCAKE_BLOCK_SIZE})",
    )
    return parser


def _needs_synthesis(args):
    """Return True if any synthesis parameter differs from its default."""
    return (
        args.num_requests is not None
        or args.speedup_ratio != 1.0
        or args.prefix_len_multiplier != 1.0
        or args.prefix_root_multiplier != 1
        or args.prompt_len_multiplier != 1.0
        or args.max_isl is not None
        or args.min_isl is not None
        or args.min_osl is not None
        or args.max_osl is not None
    )


def _read_jsonl(path):
    """Load one JSON object per non-blank line of ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. trailing newlines) instead of failing in json.loads.
        return [json.loads(line) for line in f if line.strip()]


def _count_jsonl_records(path):
    """Count non-blank lines in ``path`` (one request per line)."""
    with open(path, "r", encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as JSON Lines, one object per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)


def _inject_expected_osl(requests):
    """Copy each request's ``output_tokens`` (default 0) into ``nvext.expected_output_tokens``.

    Mutates ``requests`` in place, creating ``nvext`` when it is absent.
    """
    for request in requests:
        request.setdefault("nvext", {})["expected_output_tokens"] = request.get(
            "output_tokens", 0
        )


def _synthesize_requests(args):
    """Generate synthetic requests from ``args.input_dataset`` using ``Synthesizer``.

    Seeds NumPy's global RNG with ``args.seed`` before constructing the
    synthesizer.
    """
    logger.info("Generating synthetic trace data...")
    logger.info(f" Base dataset: {args.input_dataset}")
    # Compare against None so explicit 0 values are reported, not shown as unset.
    logger.info(
        f" Num requests: {args.num_requests if args.num_requests is not None else 'all'}"
    )
    logger.info(f" Speedup ratio: {args.speedup_ratio}")
    logger.info(f" Prefix len multiplier: {args.prefix_len_multiplier}")
    logger.info(f" Prefix root multiplier: {args.prefix_root_multiplier}")
    logger.info(f" Prompt len multiplier: {args.prompt_len_multiplier}")
    logger.info(
        f" Max ISL: {args.max_isl if args.max_isl is not None else 'no limit'} (filtering)"
    )
    logger.info(
        f" Min ISL: {args.min_isl if args.min_isl is not None else 'no limit'} (filtering)"
    )
    logger.info(
        f" Min OSL: {args.min_osl if args.min_osl is not None else 'no clipping'} (clipping)"
    )
    logger.info(
        f" Max OSL: {args.max_osl if args.max_osl is not None else 'no clipping'} (clipping)"
    )
    logger.info(f" Random seed: {args.seed}")

    # Set random seed for reproducibility (before the synthesizer is built).
    np.random.seed(args.seed)

    synthesizer = Synthesizer(
        args.input_dataset,
        block_size=args.block_size,
        speedup_ratio=args.speedup_ratio,
        prefix_len_multiplier=args.prefix_len_multiplier,
        prefix_root_multiplier=args.prefix_root_multiplier,
        prompt_len_multiplier=args.prompt_len_multiplier,
    )

    if args.num_requests is None:
        num_requests = _count_jsonl_records(args.input_dataset)
        logger.info(f"Using all {num_requests} requests from input dataset")
    else:
        num_requests = args.num_requests

    requests = synthesizer.synthesize_requests(
        num_requests,
        max_isl=args.max_isl,
        min_isl=args.min_isl,
        min_osl=args.min_osl,
        max_osl=args.max_osl,
    )
    logger.info(f"Generated {len(requests)} synthetic requests")
    return requests


def _prepare_trace_dataset(args):
    """Return the path of the trace file to benchmark, writing a derived one if needed.

    - No synthesis params and no ``--use-expected-osl``: the input file as-is.
    - Only ``--use-expected-osl``: a copy with ``nvext.expected_output_tokens``.
    - Otherwise: a synthesized trace (optionally with expected OSL injected).
    """
    needs_synthesis = _needs_synthesis(args)

    if not needs_synthesis and not args.use_expected_osl:
        trace_dataset_path = args.input_dataset
        logger.info(
            f"Using original trace dataset (no synthesis parameters modified): {trace_dataset_path}"
        )
        return trace_dataset_path

    if not needs_synthesis:
        # Only inject expected_output_tokens into nvext, no other synthesis
        logger.info("Injecting expected_output_tokens into original trace dataset...")
        requests = _read_jsonl(args.input_dataset)
        _inject_expected_osl(requests)
        trace_dataset_path = os.path.join(
            args.output_dir, "trace_with_expected_osl.jsonl"
        )
        _write_jsonl(trace_dataset_path, requests)
        logger.info(f"Modified trace data saved to: {trace_dataset_path}")
        return trace_dataset_path

    requests = _synthesize_requests(args)
    trace_dataset_path = os.path.join(args.output_dir, "synthetic_trace.jsonl")
    if args.use_expected_osl:
        _inject_expected_osl(requests)
        logger.info("Injected expected_output_tokens into nvext for each request")
    _write_jsonl(trace_dataset_path, requests)
    logger.info(f"Synthetic trace data saved to: {trace_dataset_path}")
    return trace_dataset_path


def main():
    """CLI entry point: prepare the trace dataset, then run aiperf against it."""
    args = _build_parser().parse_args()
    resolve_tokenizer(args)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    trace_dataset_path = _prepare_trace_dataset(args)

    # Run benchmark with the trace dataset
    artifact_dir = os.path.join(args.output_dir, "aiperf_artifacts")
    os.makedirs(artifact_dir, exist_ok=True)

    run_benchmark_with_trace(
        args.model,
        args.tokenizer,
        trace_dataset_path,
        artifact_dir,
        args.url,
        args.seed,
        args.block_size,
    )

    logger.info(f"Results saved to: {artifact_dir}")


if __name__ == "__main__":
    main()