#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import re
import sys
from typing import Dict, Tuple
from urllib.parse import urlsplit

from benchmarks.utils.workflow import has_http_scheme, run_benchmark_workflow
from deploy.utils.kubernetes import is_running_in_cluster


def validate_inputs(inputs: Dict[str, str]) -> None:
    """Validate every benchmark input and reject reserved labels.

    A value is accepted when ``has_http_scheme`` approves it. When the
    process is running inside a cluster (per ``is_running_in_cluster``), a
    scheme-less address of the form ``host[:port][/path]`` is accepted as
    well, provided it has a host and, if a port is given, the port lies in
    1-65535. The label ``plots`` is rejected case-insensitively.

    Args:
        inputs: Mapping of label to endpoint string.

    Raises:
        ValueError: If a value is not an acceptable endpoint, or a label is
            the reserved name ``plots``.
    """
    for label, value in inputs.items():
        v = value.strip()

        if not has_http_scheme(v):
            if not is_running_in_cluster():
                raise ValueError(f"Input '{label}' must be HTTP endpoint. Got: {value}")

            # Prefix with "//" so urlsplit treats the text as a network
            # location rather than as a path.
            try:
                parts = urlsplit(f"//{v}")
                host_ok = bool(parts.hostname)
                # Reading .port raises ValueError for a non-numeric or
                # out-of-range port; treat that as a failed check so the
                # caller sees this function's message, not urllib's.
                port_ok = parts.port is None or (1 <= parts.port <= 65535)
            except ValueError:
                host_ok = port_ok = False

            if not (host_ok and port_ok):
                raise ValueError(
                    f"Input '{label}' must be HTTP(S) or internal service URL. Got: {value}"
                )

        # Validate reserved labels
        if label.lower() == "plots":
            raise ValueError("Label 'plots' is reserved")


def parse_input(input_str: str) -> Tuple[str, str]:
    """Parse a ``label=value`` string into a stripped ``(label, value)`` pair.

    Only the first ``=`` separates label from value, so the value may itself
    contain ``=`` characters (for example a URL with a query string).

    Args:
        input_str: Raw text in the form ``label=value``.

    Returns:
        The label and value with surrounding whitespace removed.

    Raises:
        ValueError: If there is no ``=``, the label or value is empty after
            stripping, or the label contains characters other than ASCII
            letters, digits, ``_`` and ``-``.
    """
    if "=" not in input_str:
        raise ValueError(f"Invalid input format: {input_str}")

    # The guard above guarantees exactly two parts when splitting once.
    label, value = input_str.split("=", 1)
    label = label.strip()
    value = value.strip()

    if not label:
        raise ValueError("Empty label")
    if not value:
        raise ValueError("Empty value")

    # Validate label characters
    if not re.match(r"^[a-zA-Z0-9_-]+$", label):
        raise ValueError(f"Invalid label: {label}")

    return label, value


def main() -> int:
    """Run the benchmark orchestrator command-line interface.

    Parses the command-line arguments, turns each ``--input`` into a
    label/endpoint pair via ``parse_input``, checks them with
    ``validate_inputs`` and passes everything to ``run_benchmark_workflow``.

    Returns:
        0 once the workflow call has returned; 1 when no ``--input`` was
        given, a label is duplicated, or parsing/validation raises
        ``ValueError``. Malformed command-line arguments are reported by
        ``argparse`` itself, which exits the process.
    """
    parser = argparse.ArgumentParser(description="Benchmark Orchestrator")
    parser.add_argument(
        "--input",
        action="append",
        dest="inputs",
        help="Input in format <label>=<endpoint>. Can be specified multiple times for comparisons.",
    )
    parser.add_argument("--isl", type=int, default=2000, help="Input sequence length")
    parser.add_argument(
        "--std",
        type=int,
        default=10,
        help="Input sequence standard deviation",
    )
    parser.add_argument("--osl", type=int, default=256, help="Output sequence length")
    parser.add_argument(
        "--model",
        default="Qwen/Qwen3-0.6B",
        help="Model name",
    )
    parser.add_argument(
        "--output-dir", type=str, default="benchmarks/results", help="Output directory"
    )
    args = parser.parse_args()

    # With action="append" and no default, args.inputs is None when the flag
    # was never passed.
    if not args.inputs:
        print("ERROR: At least one --input must be specified")
        return 1

    # Input count above which a warning about plotting is printed.
    max_plot_inputs = 12

    # Parse inputs
    try:
        parsed_inputs = {}
        for input_str in args.inputs:
            label, value = parse_input(input_str)
            if label in parsed_inputs:
                print(
                    f"ERROR: Duplicate label '{label}' found. Each label must be unique."
                )
                return 1
            parsed_inputs[label] = value

        # Exceeding the plotting limit is only a warning; the run continues.
        if len(parsed_inputs) > max_plot_inputs:
            print(
                f"WARNING: You provided {len(parsed_inputs)} inputs, but the plotting system supports up to {max_plot_inputs} inputs."
            )
            print(
                "Consider running separate benchmark sessions or grouping related comparisons together."
            )
            print(
                "Continuing with benchmark, but some inputs may not appear in plots..."
            )
            print()

        # Validate that inputs are HTTP endpoints or in-cluster service URLs
        validate_inputs(parsed_inputs)

    except ValueError as e:
        print(f"ERROR: {e}")
        return 1

    # Run the benchmark workflow with the parsed inputs
    run_benchmark_workflow(
        inputs=parsed_inputs,
        isl=args.isl,
        std=args.std,
        osl=args.osl,
        model=args.model,
        output_dir=args.output_dir,
    )
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)