config.py 5.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml


@dataclass
class BenchmarkConfig:
    """One benchmark entry of a sweep: which workflow script to run, and with what arguments."""

    # Short name identifying this benchmark; it appears in validation errors.
    label: str
    # Path to the workflow script. It may be relative: SweepConfig.validate
    # joins relative paths onto ``repo_root`` when one is supplied.
    workflow: str
    # Extra command-line arguments for the workflow. Each instance gets its
    # own fresh list via the default factory.
    extra_args: List[str] = field(default_factory=list)

@dataclass
class SweepConfig:
    """Top-level sweep configuration loaded from YAML with optional CLI overrides.

    Exactly one of ``request_rates`` or ``concurrencies`` must be set (this is
    enforced by :meth:`validate`). The active mode is exposed via
    ``sweep_mode`` and ``sweep_values``.
    """

    model: str = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
    # Sweep axis: exactly one of these two lists is expected to be non-empty.
    request_rates: Optional[List[int]] = None
    concurrencies: Optional[List[int]] = None
    # NOTE(review): presumably the output sequence length in tokens — confirm.
    osl: int = 150
    conversation_num: Optional[int] = None
    warmup_count: int = 5
    port: int = 8000
    timeout: int = 600
    input_files: List[str] = field(default_factory=list)
    configs: List[BenchmarkConfig] = field(default_factory=list)
    output_dir: str = "benchmarks/results/multimodal_default"
    skip_plots: bool = False
    restart_server_every_benchmark: bool = True
    # NOTE(review): presumably extra environment variables (name -> value)
    # for the benchmark runs — confirm against the runner.
    env: Dict[str, str] = field(default_factory=dict)

    @property
    def sweep_mode(self) -> str:
        """Return ``'request_rate'`` or ``'concurrency'``."""
        # A non-empty concurrencies list wins; everything else is treated as
        # request-rate mode (validate() rejects the ambiguous cases).
        if self.concurrencies:
            return "concurrency"
        return "request_rate"

    @property
    def sweep_values(self) -> List[int]:
        """Return the active sweep values (request_rates or concurrencies)."""
        if self.concurrencies:
            return self.concurrencies
        return self.request_rates or []

    def validate(self, repo_root: Optional[Path] = None) -> None:
        """Validate the sweep settings and check that all referenced files exist.

        Args:
            repo_root: When given, relative workflow script paths are resolved
                against it; otherwise they are resolved against the current
                working directory. Input files are always resolved against the
                current working directory.

        Raises:
            ValueError: If there are no input files or no benchmark configs, or
                if not exactly one of ``request_rates``/``concurrencies`` is set.
            FileNotFoundError: If an input file or a workflow script is missing.
        """
        if not self.input_files:
            raise ValueError("At least one input_file is required.")
        if not self.configs:
            raise ValueError("At least one benchmark config is required.")

        for f in self.input_files:
            if not Path(f).is_file():
                raise FileNotFoundError(f"Input file not found: {f}")

        for cfg in self.configs:
            script = Path(cfg.workflow)
            if repo_root and not script.is_absolute():
                script = repo_root / script
            if not script.is_file():
                raise FileNotFoundError(
                    f"Workflow script not found: {script} (config '{cfg.label}')"
                )

        # The two sweep modes are mutually exclusive, and one is required.
        if self.request_rates and self.concurrencies:
            raise ValueError(
                "Cannot set both request_rates and concurrencies. Pick one."
            )
        if not self.request_rates and not self.concurrencies:
            raise ValueError(
                "At least one of request_rates or concurrencies is required."
            )


_DEFAULT_REQUEST_RATES: List[int] = [4, 8, 16, 32, 64]
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110


def _parse_benchmark_config(raw: Dict[str, Any]) -> BenchmarkConfig:
    """Build a BenchmarkConfig from one entry of the YAML ``configs`` list.

    Args:
        raw: Mapping with required ``label`` and ``workflow`` keys and an
            optional ``extra_args`` list.

    Returns:
        The parsed BenchmarkConfig. Every extra argument is coerced to ``str``,
        so numeric YAML values can be used directly as arguments.

    Raises:
        KeyError: If ``label`` or ``workflow`` is missing.
        TypeError: If ``extra_args`` is a bare string instead of a list.
    """
    label = raw["label"]
    workflow = raw["workflow"]

    extra_args = raw.get("extra_args")
    if extra_args is None:
        # Both an absent key and an explicit ``extra_args:`` (YAML null) mean
        # "no extra arguments".
        extra_args = []
    elif isinstance(extra_args, str):
        # Iterating a string would silently split it into one-character args.
        raise TypeError(
            f"extra_args for config {label!r} must be a list, "
            f"got a string: {extra_args!r}"
        )

    return BenchmarkConfig(
        label=label,
        workflow=workflow,
        extra_args=[str(a) for a in extra_args],
    )


def load_config(
    yaml_path: str,
    cli_overrides: Optional[Dict[str, Any]] = None,
) -> SweepConfig:
    """Load a SweepConfig from a YAML file, applying optional CLI overrides.

    Args:
        yaml_path: Path to the sweep YAML file.
        cli_overrides: Optional mapping of SweepConfig field names to values.
            ``None`` values are treated as "not given" and skipped. Keys that
            are not SweepConfig fields are also skipped. A non-None
            ``request_rates`` override clears ``concurrencies``, and vice
            versa; if both are given, ``request_rates`` wins.

    Returns:
        The populated SweepConfig. It is not validated here; call
        ``SweepConfig.validate`` for that.

    Raises:
        ValueError: If the YAML top level is not a mapping, or if the YAML sets
            both ``request_rates`` and ``concurrencies``.
    """
    from dataclasses import fields as dataclass_fields

    with open(yaml_path, encoding="utf-8") as f:
        # An empty YAML document loads as None; treat it as an empty mapping.
        raw = yaml.safe_load(f) or {}
    if not isinstance(raw, dict):
        raise ValueError(
            f"YAML config {yaml_path} must contain a mapping at the top level, "
            f"got {type(raw).__name__}."
        )

    # ``or []`` / ``or {}`` below also cover keys that are present but null.
    configs = [_parse_benchmark_config(c) for c in raw.get("configs") or []]

    # Resolve sweep mode from YAML — support both keys, default to request_rates.
    yaml_request_rates = raw.get("request_rates")
    yaml_concurrencies = raw.get("concurrencies")

    if yaml_request_rates and yaml_concurrencies:
        raise ValueError(
            f"YAML config {yaml_path} sets both request_rates and concurrencies. "
            "Pick one."
        )

    # Default to request_rates if neither is specified. Copy the default so a
    # later mutation of cfg.request_rates cannot alter the module-level list.
    if not yaml_request_rates and not yaml_concurrencies:
        yaml_request_rates = list(_DEFAULT_REQUEST_RATES)

    cfg = SweepConfig(
        model=raw.get("model", "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"),
        request_rates=yaml_request_rates,
        concurrencies=yaml_concurrencies,
        osl=raw.get("osl", 150),
        conversation_num=raw.get("conversation_num"),
        warmup_count=raw.get("warmup_count", 5),
        port=raw.get("port", 8000),
        timeout=raw.get("timeout", 600),
        input_files=raw.get("input_files") or [],
        configs=configs,
        output_dir=raw.get("output_dir", "benchmarks/results/multimodal_default"),
        skip_plots=raw.get("skip_plots", False),
        restart_server_every_benchmark=raw.get("restart_server_every_benchmark", True),
        # YAML parses unquoted numbers/booleans as non-strings, but env is
        # declared Dict[str, str]; coerce here (note: YAML true -> "True").
        env={str(k): str(v) for k, v in (raw.get("env") or {}).items()},
    )

    if cli_overrides:
        # Only real dataclass fields may be overridden. This keeps a stray key
        # from replacing a method (e.g. ``validate``) or failing on a
        # read-only property (e.g. ``sweep_mode``).
        field_names = {fld.name for fld in dataclass_fields(cfg)}
        for key, value in cli_overrides.items():
            if value is None or key not in field_names:
                continue
            setattr(cfg, key, value)

        # CLI sweep mode override clears the other (mutually exclusive) field.
        if cli_overrides.get("request_rates") is not None:
            cfg.concurrencies = None
        elif cli_overrides.get("concurrencies") is not None:
            cfg.request_rates = None

    return cfg


def input_file_tag(path: str) -> str:
    """Turn a JSONL path into a short tag usable as a directory name.

    The tag is the file's final path component minus its last suffix, with
    every space replaced by an underscore.
    """
    stem = Path(path).stem
    return "_".join(stem.split(" "))


def resolve_repo_root() -> Path:
    """Return the closest directory, from the working directory upward, that holds pyproject.toml.

    The search begins at the resolved current directory and moves toward the
    filesystem root. The root itself is never inspected. If no
    ``pyproject.toml`` is found, the resolved current directory is returned.
    """
    start = Path.cwd().resolve()
    # ``start.parents`` ends with the filesystem root; dropping the final
    # entry of the combined list keeps the root out of the search.
    search_path = [start, *start.parents][:-1]
    return next(
        (d for d in search_path if (d / "pyproject.toml").is_file()),
        start,
    )