benchmark.py 3.17 KB
Newer Older
1
2
#!/usr/bin/env python3

3
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
5
6
# SPDX-License-Identifier: Apache-2.0

import argparse
7
import re
8
import sys
9
from urllib.parse import urlsplit
10

11
12
from benchmarks.utils.workflow import has_http_scheme, run_benchmark_workflow
from deploy.utils.kubernetes import is_running_in_cluster
13
14


15
16
17
18
19
20
21
def validate_endpoint(endpoint: str) -> None:
    """Validate the endpoint that is about to be benchmarked.

    Surrounding whitespace is ignored. An endpoint accepted by
    ``has_http_scheme`` is always valid. When ``is_running_in_cluster()``
    reports that the process runs inside a cluster, a scheme-less internal
    service address of the form ``host[:port][/path]`` is accepted as well,
    provided it has a host and, if a port is given, the port lies in 1-65535.

    Args:
        endpoint: Endpoint string as supplied by the user.

    Raises:
        ValueError: If the endpoint is not acceptable in the current
            environment.
    """
    candidate = endpoint.strip()
    in_cluster = is_running_in_cluster()

    if has_http_scheme(candidate):
        return

    if not in_cluster:
        raise ValueError(f"Endpoint must be HTTP endpoint. Got: {endpoint}")

    message = f"Endpoint must be HTTP(S) or internal service URL. Got: {endpoint}"
    try:
        # The leading "//" makes urlsplit treat the text as a network
        # location, so host and port are parsed without a scheme.
        parts = urlsplit(f"//{candidate}")
        host = parts.hostname
        # Reading .port raises ValueError for a non-numeric or out-of-range
        # port; report that with the same message as every other rejection.
        port = parts.port
    except ValueError as exc:
        raise ValueError(message) from exc

    port_ok = port is None or 1 <= port <= 65535
    if not host or not port_ok:
        raise ValueError(message)
33
34


35
36
37
38
def validate_benchmark_name(name: str) -> None:
    """Validate the name/label given to a benchmark run.

    Surrounding whitespace is ignored. What remains must be non-empty, must
    consist only of ASCII letters, digits, underscores and hyphens, and must
    not be the reserved word "plots" (compared case-insensitively).

    Args:
        name: Candidate benchmark name.

    Raises:
        ValueError: If the name is empty, contains any other character, or
            is reserved.
    """
    name = name.strip()
    if not name:
        raise ValueError("Benchmark name cannot be empty")

    # fullmatch anchors at both ends, so the pattern needs no ^ or $.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", name):
        raise ValueError(f"Invalid benchmark name: {name}")

    # Validate reserved names
    if name.lower() == "plots":
        raise ValueError("Benchmark name 'plots' is reserved")
49
50
51
52
53


def main() -> int:
    """Parse the command line, validate it, and run the benchmark workflow.

    Returns:
        Process exit status: 1 when the benchmark name or endpoint fails
        validation (the reason is written to stderr), otherwise 0 once
        ``run_benchmark_workflow`` has returned.
    """
    parser = argparse.ArgumentParser(description="Benchmark Orchestrator")
    parser.add_argument(
        "--benchmark-name",
        required=True,
        help="Name/label for this benchmark (used in plots and results)",
    )
    parser.add_argument(
        "--endpoint-url",
        required=True,
        help="Endpoint to benchmark: HTTP(S) URL (e.g., http://localhost:8000) or in-cluster service URL host[:port]",
    )
    parser.add_argument("--isl", type=int, default=2000, help="Input sequence length")
    parser.add_argument(
        "--std",
        type=int,
        default=10,
        help="Input sequence standard deviation",
    )
    parser.add_argument("--osl", type=int, default=256, help="Output sequence length")
    parser.add_argument(
        "--model",
        default="Qwen/Qwen3-0.6B",
        help="Model name (must match the model deployed at the endpoint)",
    )
    parser.add_argument(
        "--output-dir", type=str, default="benchmarks/results", help="Output directory"
    )
    args = parser.parse_args()

    # The validators judge the whitespace-stripped text, so strip once here
    # and hand the workflow exactly the values that were validated.
    benchmark_name = args.benchmark_name.strip()
    endpoint_url = args.endpoint_url.strip()

    # Validate inputs
    try:
        validate_benchmark_name(benchmark_name)
        validate_endpoint(endpoint_url)
    except ValueError as e:
        # Diagnostics belong on stderr so stdout stays clean for results.
        print(f"ERROR: {e}", file=sys.stderr)
        return 1

    # Run the benchmark workflow with the parsed inputs
    run_benchmark_workflow(
        inputs={benchmark_name: endpoint_url},
        isl=args.isl,
        std=args.std,
        osl=args.osl,
        model=args.model,
        output_dir=args.output_dir,
    )
    return 0


# Script entry point: exit with the status code that main() reports.
if __name__ == "__main__":
    raise SystemExit(main())