"deploy/helm/charts/vscode:/vscode.git/clone" did not exist on "53a609e538fd787903de3172ccddd36e65017f44"
benchmark.py 4.11 KB
Newer Older
1
2
3
4
5
6
#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
7
import re
8
import sys
9
from typing import Dict, Tuple
10

11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from benchmarks.utils.workflow import run_benchmark_workflow


def validate_inputs(inputs: Dict[str, str]) -> None:
    """Ensure every parsed input points at an HTTP(S) endpoint and uses a legal label.

    Entries are checked one at a time in the mapping's iteration order. For
    each entry the endpoint check runs before the reserved-label check, so
    the first problem encountered is the one reported.

    Args:
        inputs: Mapping from label to endpoint string.

    Raises:
        ValueError: If an endpoint does not start with ``http://`` or
            ``https://`` (compared case-insensitively), or if a label equals
            ``plots`` in any letter case, which is a reserved label.
    """
    http_prefixes = ("http://", "https://")
    for name, endpoint in inputs.items():
        looks_like_http = endpoint.lower().startswith(http_prefixes)
        if not looks_like_http:
            raise ValueError(
                f"Input '{name}' must be an HTTP endpoint (starting with http:// or https://). Got: {endpoint}"
            )

        # The reserved-name comparison ignores case, so "Plots" is rejected too.
        if name.lower() == "plots":
            raise ValueError(
                "Label 'plots' is reserved and cannot be used. Please choose a different label."
            )


def parse_input(input_str: str) -> Tuple[str, str]:
    """Parse one ``--input`` argument of the form ``<label>=<endpoint>``.

    The string is split on the first ``=`` only, so the endpoint may itself
    contain ``=`` (for example in a query string). Surrounding whitespace is
    stripped from both halves before they are checked.

    Args:
        input_str: Raw value of a single ``--input`` argument.

    Returns:
        A ``(label, value)`` tuple with surrounding whitespace removed.

    Raises:
        ValueError: If there is no ``=``, if the label or the value is empty
            or whitespace-only (label checked first), or if the label contains
            anything other than ASCII letters, digits, ``-`` and ``_``.
    """
    # partition() splits on the first "=" only and always returns three parts;
    # an empty separator means no "=" was present at all.
    raw_label, sep, raw_value = input_str.partition("=")
    if not sep:
        raise ValueError(
            f"Invalid input format. Expected: <label>=<endpoint>, got: {input_str}"
        )

    label = raw_label.strip()
    value = raw_value.strip()
    if not label:
        raise ValueError("Label cannot be empty")
    if not value:
        raise ValueError("Value cannot be empty")

    # fullmatch anchors both ends of the (already stripped) label.
    if re.fullmatch(r"[a-zA-Z0-9_-]+", label) is None:
        raise ValueError(
            f"Label must contain only letters, numbers, hyphens, and underscores. Invalid label: {label}"
        )

    return label, value


def main() -> int:
    """Command-line entry point for the benchmark orchestrator.

    Parses the command line and turns each ``--input <label>=<endpoint>``
    into an entry of a ``{label: endpoint}`` dict. The dict is validated with
    ``parse_input``/``validate_inputs`` and then passed, together with the
    remaining options, to ``run_benchmark_workflow``.

    Returns:
        Process exit status: 0 once ``run_benchmark_workflow`` returns, or 1
        if no ``--input`` was given, a label is duplicated, or any input fails
        parsing/validation. Error and warning messages are written to stderr.
    """
    # Inputs beyond this count trigger a warning; the plotting step is
    # described (in the warning text) as supporting at most this many.
    max_plotted_inputs = 12

    parser = argparse.ArgumentParser(description="Benchmark Orchestrator")
    parser.add_argument(
        "--input",
        action="append",
        dest="inputs",
        help="Input in format <label>=<endpoint>. Can be specified multiple times for comparisons.",
    )
    parser.add_argument("--isl", type=int, default=2000, help="Input sequence length")
    parser.add_argument(
        "--std",
        type=int,
        default=10,
        help="Input sequence standard deviation",
    )
    parser.add_argument("--osl", type=int, default=256, help="Output sequence length")
    parser.add_argument(
        "--model",
        default="Qwen/Qwen3-0.6B",
        help="Model name",
    )
    parser.add_argument(
        "--output-dir", type=str, default="benchmarks/results", help="Output directory"
    )
    args = parser.parse_args()

    # With action="append" and no default, args.inputs is None when --input
    # is never passed.
    if not args.inputs:
        print("ERROR: At least one --input must be specified", file=sys.stderr)
        return 1

    try:
        parsed_inputs = {}
        for input_str in args.inputs:
            label, value = parse_input(input_str)
            if label in parsed_inputs:
                print(
                    f"ERROR: Duplicate label '{label}' found. Each label must be unique.",
                    file=sys.stderr,
                )
                return 1
            parsed_inputs[label] = value

        # Validate that all inputs are HTTP endpoints (and labels are not reserved)
        validate_inputs(parsed_inputs)
    except ValueError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        return 1

    # Warn only after validation passed, so the "Continuing..." notice is never
    # printed for a run that is about to be rejected.
    if len(parsed_inputs) > max_plotted_inputs:
        print(
            f"WARNING: You provided {len(parsed_inputs)} inputs, but the plotting system supports up to {max_plotted_inputs} inputs.",
            file=sys.stderr,
        )
        print(
            "Consider running separate benchmark sessions or grouping related comparisons together.",
            file=sys.stderr,
        )
        print(
            "Continuing with benchmark, but some inputs may not appear in plots...",
            file=sys.stderr,
        )
        print(file=sys.stderr)

    # Run the benchmark workflow with the parsed inputs
    run_benchmark_workflow(
        inputs=parsed_inputs,
        isl=args.isl,
        std=args.std,
        osl=args.osl,
        model=args.model,
        output_dir=args.output_dir,
    )
    return 0


if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())