#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import shlex
import subprocess

from common import (
    add_common_args,
    add_synthesis_args,
    get_aiperf_cmd_for_trace,
    prepare_trace_dataset,
    resolve_tokenizer,
    setup_logger,
)

# Module-level logger used by run_benchmark_with_trace and main, and handed
# to prepare_trace_dataset so dataset preparation logs through it as well.
logger = setup_logger(__name__)


def run_benchmark_with_trace(
    model,
    tokenizer,
    trace_dataset,
    artifact_dir,
    url,
    seed,
    block_size,
):
    """Run an aiperf benchmark that replays a trace dataset.

    The aiperf command line is built by ``get_aiperf_cmd_for_trace`` and run
    as a child process. This call blocks until aiperf exits. The child's
    stdout/stderr are not captured, so aiperf output goes straight to the
    terminal.

    Args:
        model: Model name forwarded to the aiperf command builder.
        tokenizer: Tokenizer forwarded to the aiperf command builder.
        trace_dataset: Path to the trace dataset file to replay.
        artifact_dir: Directory passed to aiperf for its output artifacts.
        url: Server URL forwarded to the aiperf command builder.
        seed: Random seed forwarded to the aiperf command builder.
        block_size: Block size forwarded to the aiperf command builder.

    Raises:
        subprocess.CalledProcessError: If aiperf exits with a non-zero status.
    """
    # NOTE(review): the builder receives ``url`` last, unlike this function's
    # parameter order; confirm against common.get_aiperf_cmd_for_trace.
    aiperf_cmd = get_aiperf_cmd_for_trace(
        model,
        tokenizer,
        trace_dataset,
        artifact_dir,
        seed,
        block_size,
        url,
    )

    logger.info("Running aiperf with trace dataset: %s", trace_dataset)
    # shlex.join quotes arguments containing spaces or shell metacharacters,
    # so the logged command can be copy-pasted into a shell verbatim.
    logger.info("Command: %s", shlex.join(str(arg) for arg in aiperf_cmd))

    try:
        # Output is deliberately not captured: aiperf writes to the terminal.
        subprocess.run(aiperf_cmd, check=True)
    except subprocess.CalledProcessError as e:
        logger.error("AIPerf failed with error code: %s", e.returncode)
        # e.stderr is only populated when stderr was captured. Here it was
        # inherited by the terminal, so it is normally None and there is
        # nothing useful to log.
        if e.stderr:
            logger.error("stderr: %s", e.stderr)
        raise

    logger.info("AIPerf profiling completed successfully")


def main():
    """Command-line entry point: prepare a trace dataset and benchmark it.

    Steps:
    1. Parse the arguments registered by ``add_common_args`` and
       ``add_synthesis_args``, then resolve the tokenizer.
    2. Prepare the trace dataset under ``args.output_dir``.
    3. Run aiperf against it, with artifacts written to
       ``<output_dir>/aiperf_artifacts``.

    Raises:
        subprocess.CalledProcessError: Propagated from
            ``run_benchmark_with_trace`` when aiperf fails.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark with real or synthesized mooncake-style trace data"
    )
    add_common_args(parser)
    add_synthesis_args(parser)

    args = parser.parse_args()
    # Presumably fills in / normalizes args.tokenizer; see common.resolve_tokenizer.
    resolve_tokenizer(args)

    # Created before dataset preparation, which is handed output_dir
    # (presumably to write the trace file into it).
    os.makedirs(args.output_dir, exist_ok=True)
    _, trace_dataset_path = prepare_trace_dataset(args, args.output_dir, logger)

    artifact_dir = os.path.join(args.output_dir, "aiperf_artifacts")
    os.makedirs(artifact_dir, exist_ok=True)

    run_benchmark_with_trace(
        args.model,
        args.tokenizer,
        trace_dataset_path,
        artifact_dir,
        args.url,
        args.seed,
        args.block_size,
    )

    logger.info("Results saved to: %s", artifact_dir)


if __name__ == "__main__":
    main()