#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio
import re
import sys
from typing import Tuple

from benchmarks.utils.workflow import categorize_inputs, run_benchmark_workflow


def parse_input(input_str: str) -> Tuple[str, str]:
    """Split a ``<label>=<value>`` CLI argument into a validated pair.

    Only the first ``=`` separates label from value, so the value may itself
    contain ``=`` (e.g. a URL with a query string). Surrounding whitespace is
    stripped from both parts.

    Args:
        input_str: Raw ``--input`` argument in the form
            ``<label>=<manifest_path_or_endpoint>``.

    Returns:
        ``(label, value)`` with surrounding whitespace removed.

    Raises:
        ValueError: if there is no ``=``, if either side is empty or blank,
            or if the label contains characters other than ASCII letters,
            digits, ``-`` and ``_``.
    """
    # partition() splits on the first "=" only; an empty separator means none was found.
    label, sep, value = input_str.partition("=")
    if not sep:
        raise ValueError(
            f"Invalid input format. Expected: <label>=<manifest_path_or_endpoint>, got: {input_str}"
        )

    label = label.strip()
    value = value.strip()
    if not label:
        raise ValueError("Label cannot be empty")
    if not value:
        raise ValueError("Value cannot be empty")

    # fullmatch anchors both ends of the string, so no explicit ^/$ is needed.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", label):
        raise ValueError(
            f"Label must contain only letters, numbers, hyphens, and underscores. Invalid label: {label}"
        )

    return label, value


def main() -> int:
    """Parse CLI arguments, validate them, and run the benchmark workflow.

    Returns:
        Process exit status: 0 after the workflow completes, 1 when the
        command-line arguments or ``--input`` values are invalid.
    """
    parser = argparse.ArgumentParser(description="Benchmark Orchestrator")
    parser.add_argument(
        "--input",
        action="append",
        dest="inputs",
        help="Input in format <label>=<manifest_path_or_endpoint>. Can be specified multiple times for comparisons.",
    )
    parser.add_argument("--namespace", required=True, help="Kubernetes namespace")
    parser.add_argument("--isl", type=int, default=2000, help="Input sequence length")
    parser.add_argument(
        "--std",
        type=int,
        default=10,
        help="Input sequence standard deviation",
    )
    parser.add_argument("--osl", type=int, default=256, help="Output sequence length")
    parser.add_argument(
        "--model",
        default="Qwen/Qwen3-0.6B",
        help="Model name",
    )
    parser.add_argument(
        "--output-dir", type=str, default="benchmarks/results", help="Output directory"
    )
    args = parser.parse_args()

    # Validate inputs
    if not args.inputs:
        print("ERROR: At least one --input must be specified", file=sys.stderr)
        return 1

    # Sequence lengths must be positive and the deviation non-negative;
    # reject them here instead of failing somewhere inside the workflow.
    if args.isl <= 0 or args.osl <= 0:
        print("ERROR: --isl and --osl must be positive integers", file=sys.stderr)
        return 1
    if args.std < 0:
        print("ERROR: --std must be a non-negative integer", file=sys.stderr)
        return 1

    # Parse inputs into an ordered label -> manifest/endpoint mapping.
    parsed_inputs = {}
    try:
        for input_str in args.inputs:
            label, value = parse_input(input_str)
            if label in parsed_inputs:
                print(
                    f"ERROR: Duplicate label '{label}' found. Each label must be unique.",
                    file=sys.stderr,
                )
                return 1
            parsed_inputs[label] = value

        # The (endpoints, manifests) split is not needed here. The call is kept so
        # that bad inputs fail before the workflow starts. NOTE(review): this
        # assumes categorize_inputs raises FileNotFoundError/ValueError for
        # invalid inputs — confirm.
        categorize_inputs(parsed_inputs)
    except (ValueError, FileNotFoundError) as e:
        print(f"ERROR: {e}", file=sys.stderr)
        return 1

    # Check for plotting limitations
    max_plotted_inputs = 12
    if len(parsed_inputs) > max_plotted_inputs:
        print(
            f"WARNING: You provided {len(parsed_inputs)} inputs, but the plotting system supports up to {max_plotted_inputs} inputs.",
            file=sys.stderr,
        )
        print(
            "Consider running separate benchmark sessions or grouping related comparisons together.",
            file=sys.stderr,
        )
        print(
            "Continuing with benchmark, but some inputs may not appear in plots...",
            file=sys.stderr,
        )
        print(file=sys.stderr)

    # Run the benchmark workflow with the parsed inputs
    asyncio.run(
        run_benchmark_workflow(
            namespace=args.namespace,
            inputs=parsed_inputs,
            isl=args.isl,
            std=args.std,
            osl=args.osl,
            model=args.model,
            output_dir=args.output_dir,
        )
    )
    return 0


if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    raise SystemExit(main())