config.py 4.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import os
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml


@dataclass
class BenchmarkConfig:
    """One entry of a sweep: the workflow script to run plus its arguments.

    Instances are normally produced from the ``configs`` list of the sweep
    YAML, but can also be constructed directly.
    """

    # Identifier for this entry; used in error messages to point at the
    # offending config.
    label: str
    # Path to the workflow script. A relative path is resolved against the
    # repository root by SweepConfig.validate when a root is supplied.
    workflow: str
    # Additional arguments for the workflow. default_factory gives every
    # instance its own list, so mutating one never affects another.
    extra_args: List[str] = field(default_factory=list)


@dataclass
class SweepConfig:
    """Top-level sweep configuration loaded from YAML with optional CLI overrides.

    Every field has a default, so a config can be built incrementally; call
    ``validate`` to check that the resulting sweep is usable.
    """

    model: str = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
    # Concurrency levels to sweep; validate() requires positive integers.
    concurrencies: List[int] = field(default_factory=lambda: [1, 2, 4, 8, 16, 32])
    osl: int = 150
    request_count: int = 1000
    warmup_count: int = 5
    port: int = 8000
    timeout: int = 600
    # Checked relative to the current working directory by validate().
    input_files: List[str] = field(default_factory=list)
    configs: List[BenchmarkConfig] = field(default_factory=list)
    output_dir: str = "benchmarks/results/multimodal_default"
    skip_plots: bool = False
    restart_server_every_benchmark: bool = True
    env: Dict[str, str] = field(default_factory=dict)

    def validate(self, repo_root: Optional[Path] = None) -> None:
        """Check that the sweep is usable before any benchmark is started.

        Structural checks (non-empty lists, well-formed concurrency levels)
        run first so a malformed config fails without touching the
        filesystem; file-existence checks follow.

        Args:
            repo_root: Directory against which relative ``workflow`` paths
                are resolved. When ``None`` they are resolved against the
                current working directory. ``input_files`` are always
                resolved against the current working directory.

        Raises:
            ValueError: If ``input_files``, ``configs`` or ``concurrencies``
                is empty, or a concurrency level is not a positive integer.
            FileNotFoundError: If an input file or a workflow script is
                missing.
        """
        if not self.input_files:
            raise ValueError("At least one input_file is required.")
        if not self.configs:
            raise ValueError("At least one benchmark config is required.")
        if not self.concurrencies:
            raise ValueError("At least one concurrency level is required.")
        for level in self.concurrencies:
            # bool is a subclass of int, so it has to be excluded explicitly.
            if isinstance(level, bool) or not isinstance(level, int) or level < 1:
                raise ValueError(
                    f"Concurrency levels must be positive integers, got {level!r}."
                )

        for f in self.input_files:
            if not Path(f).is_file():
                raise FileNotFoundError(f"Input file not found: {f}")

        for cfg in self.configs:
            script = Path(cfg.workflow)
            if repo_root is not None and not script.is_absolute():
                script = Path(repo_root) / script
            if not script.is_file():
                raise FileNotFoundError(
                    f"Workflow script not found: {script} (config '{cfg.label}')"
                )


def _parse_benchmark_config(raw: Dict[str, Any]) -> BenchmarkConfig:
    """Build a BenchmarkConfig from one entry of the YAML ``configs`` list.

    Args:
        raw: Mapping with required ``label`` and ``workflow`` keys and an
            optional ``extra_args`` list. A missing or ``null``
            ``extra_args`` means no extra arguments.

    Returns:
        The parsed BenchmarkConfig. Every extra argument is converted to
        ``str`` (YAML may parse values such as ``4096`` as numbers).

    Raises:
        TypeError: If ``raw`` is not a mapping, or ``extra_args`` is not a
            list/tuple.
        KeyError: If ``label`` or ``workflow`` is missing.
    """
    if not isinstance(raw, dict):
        raise TypeError(
            f"Each benchmark config must be a mapping, got "
            f"{type(raw).__name__}: {raw!r}"
        )

    extra_args = raw.get("extra_args")
    if extra_args is None:
        extra_args = []
    elif not isinstance(extra_args, (list, tuple)):
        # A bare string would otherwise be split into single characters.
        raise TypeError(
            f"extra_args must be a list (config '{raw.get('label')}'), "
            f"got {type(extra_args).__name__}: {extra_args!r}"
        )

    return BenchmarkConfig(
        label=raw["label"],
        workflow=raw["workflow"],
        extra_args=[str(a) for a in extra_args],
    )


def load_config(
    yaml_path: str,
    cli_overrides: Optional[Dict[str, Any]] = None,
) -> SweepConfig:
    """Load a SweepConfig from a YAML file, applying optional CLI overrides.

    Keys that are missing from the YAML, or explicitly set to ``null``, fall
    back to the ``SweepConfig`` defaults. An empty file therefore yields a
    default config, which ``SweepConfig.validate`` will still reject because
    it has no input files or benchmark configs. A scalar given for a list
    field (e.g. ``concurrencies: 8``) is treated as a one-element list.

    Args:
        yaml_path: Path to the YAML sweep definition (read as UTF-8).
        cli_overrides: Optional mapping of ``SweepConfig`` field names to
            values applied after the YAML. ``None`` values are skipped and
            keys that are not ``SweepConfig`` fields are ignored.

    Returns:
        The populated ``SweepConfig`` (not yet validated).

    Raises:
        ValueError: If the top level of the YAML document is not a mapping.
        TypeError / KeyError: Propagated from parsing a malformed entry of
            the ``configs`` list.
    """
    with open(yaml_path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    # yaml.safe_load returns None for an empty (or comment-only) document.
    if raw is None:
        raw = {}
    if not isinstance(raw, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {yaml_path}, "
            f"got {type(raw).__name__}."
        )

    defaults = SweepConfig()

    def pick(key: str, default: Any) -> Any:
        # An explicit ``key: null`` behaves like an absent key.
        value = raw.get(key)
        return default if value is None else value

    def pick_list(key: str, default: List[Any]) -> List[Any]:
        # Wrap a bare scalar so strings are not iterated character by
        # character and single ints are not rejected as non-iterable.
        value = pick(key, default)
        return value if isinstance(value, list) else [value]

    cfg = SweepConfig(
        model=pick("model", defaults.model),
        concurrencies=pick_list("concurrencies", defaults.concurrencies),
        osl=pick("osl", defaults.osl),
        request_count=pick("request_count", defaults.request_count),
        warmup_count=pick("warmup_count", defaults.warmup_count),
        port=pick("port", defaults.port),
        timeout=pick("timeout", defaults.timeout),
        input_files=pick_list("input_files", defaults.input_files),
        configs=[_parse_benchmark_config(c) for c in pick_list("configs", [])],
        output_dir=pick("output_dir", defaults.output_dir),
        skip_plots=pick("skip_plots", defaults.skip_plots),
        restart_server_every_benchmark=pick(
            "restart_server_every_benchmark",
            defaults.restart_server_every_benchmark,
        ),
        # YAML parses scalars such as 0 or true as int/bool; coerce to str so
        # the mapping matches its declared Dict[str, str] type.
        env={str(k): str(v) for k, v in pick("env", defaults.env).items()},
    )

    if cli_overrides:
        # Restrict overrides to real dataclass fields: hasattr() would also
        # match methods such as ``validate`` and silently replace them.
        field_names = {fld.name for fld in fields(SweepConfig)}
        for key, value in cli_overrides.items():
            if value is None or key not in field_names:
                continue
            setattr(cfg, key, value)

    return cfg


def input_file_tag(path: str) -> str:
    """Turn an input file path into a short tag usable as a directory name.

    The tag is the file's base name without its final suffix, with each
    space replaced by an underscore, e.g. ``"/x/my data.jsonl"`` becomes
    ``"my_data"``.
    """
    stem = Path(path).stem
    # Splitting on a single space and re-joining replaces every space
    # one-for-one, so runs of spaces become runs of underscores.
    return "_".join(stem.split(" "))


def resolve_repo_root() -> Path:
    """Return the nearest directory at or above the CWD with a pyproject.toml.

    The search starts at the resolved current working directory and walks
    up through every ancestor, including the filesystem root. The original
    loop stopped one level short of the root. If no ``pyproject.toml`` is
    found, the resolved current working directory is returned.
    """
    cwd = Path(os.getcwd()).resolve()
    for candidate in (cwd, *cwd.parents):
        if (candidate / "pyproject.toml").is_file():
            return candidate
    return cwd