orchestrator.py 7.71 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from pathlib import Path
from typing import List, Optional

9
from .config import BenchmarkConfig, SweepConfig, input_file_tag, resolve_repo_root
10
from .dataset_shape import count_session_ids
11
from .runner import run_aiperf_single
12
13
14
15
16
17
18
19
20
21
from .server import ServerManager


def _resolve_workflow(workflow: str, repo_root: Path) -> str:
    p = Path(workflow)
    if p.is_absolute():
        return str(p)
    return str(repo_root / p)


22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def _resolve_conversation_num(config: SweepConfig, input_file: str) -> int:
    """Choose the conversation_num to use for *input_file*.

    An explicit ``config.conversation_num`` takes precedence. When it is
    ``None``, the value is the unique session_id count that
    ``count_session_ids`` reports for the JSONL file, and that derived value
    is printed.

    Raises:
        ValueError: if the explicit value is larger than the detected
            session_id count (the SequentialSampler would wrap around).
    """
    available = count_session_ids(input_file)
    requested = config.conversation_num

    if requested is not None:
        if requested > available:
            raise ValueError(
                f"conversation_num={requested} exceeds unique "
                f"session_id count ({available}) in {input_file}. SequentialSampler "
                f"would wrap. Set conversation_num <= {available} or reshape the JSONL."
            )
        return requested

    print(
        f"  conversation_num derived from {input_file}: {available}",
        flush=True,
    )
    return available


42
43
44
45
46
47
48
49
50
51
def _print_banner(title: str, char: str = "=", width: int = 70) -> None:
    print(f"\n{char * width}")
    print(f"  {title}")
    print(f"{char * width}", flush=True)


def run_sweep(
    config: SweepConfig,
    repo_root: Optional[Path] = None,
) -> None:
    """Execute the full benchmark sweep: for each config x input file x sweep value.

    Prints a header describing the sweep and runs every benchmark config via
    ``_run_config``. Unless ``config.skip_plots`` is set, it then generates
    per-input-file comparison plots. Finally it prints a summary of where
    results were written. If the server is still running when the sweep ends
    (normally or by exception), it is stopped.

    Args:
        config: The sweep definition.
        repo_root: Base for resolving relative workflow paths; defaults to
            ``resolve_repo_root()`` when omitted.
    """
    root = resolve_repo_root() if repo_root is None else repo_root
    output_base = Path(config.output_dir)
    sweep_mode, sweep_values = config.sweep_mode, config.sweep_values

    _print_banner("Multimodal Benchmark Sweep")
    print(f"  Model:         {config.model}")
    print(f"  Input files:   {len(config.input_files)}")
    for path in config.input_files:
        print(f"                   {path}")
    print(f"  Configs:       {[c.label for c in config.configs]}")
    print(f"  Sweep mode:    {sweep_mode}")
    print(f"  Sweep values:  {sweep_values}")
    print(f"  OSL:           {config.osl}")
    if config.conversation_num is not None:
        print(f"  Conversations: {config.conversation_num} per {sweep_mode}")
    restart_policy = (
        "every run" if config.restart_server_every_benchmark else "per config"
    )
    print(f"  Restart:       {restart_policy}")
    print(f"  Output:        {output_base}")
    print(flush=True)

    server = ServerManager(port=config.port, timeout=config.timeout)
    # Copy so per-run code cannot mutate the caller's config.env mapping.
    env_overrides = dict(config.env or {})

    try:
        for bench_cfg in config.configs:
            _run_config(
                bench_cfg=bench_cfg,
                config=config,
                server=server,
                output_base=output_base,
                sweep_mode=sweep_mode,
                sweep_values=sweep_values,
                env_overrides=env_overrides,
                repo_root=root,
            )

        if not config.skip_plots:
            for input_file in config.input_files:
                _generate_plots_for_file(
                    output_base / input_file_tag(input_file),
                    [c.label for c in config.configs],
                )
    finally:
        # Safety net: stop the server if an exception escaped _run_config
        # while it was still up.
        if server.is_running:
            server.stop()

    _print_summary(config, output_base)


108
109
def _run_config(
    bench_cfg: BenchmarkConfig,
    config: SweepConfig,
    server: ServerManager,
    output_base: Path,
    sweep_mode: str,
    sweep_values: List[int],
    env_overrides: dict,
    repo_root: Path,
) -> None:
    """Run all sweep values for a single benchmark config.

    For every input file and every sweep value (ascending), a run is queued
    unless ``profile_export_aiperf.json`` already exists in its artifact
    directory. Queued runs are executed with ``run_aiperf_single``.

    The server is started either once for the whole config or, when
    ``config.restart_server_every_benchmark`` is set, once per run. Once
    started, it is stopped in a ``finally`` so a failing run does not leave it
    up.
    """
    workflow_abs = _resolve_workflow(bench_cfg.workflow, repo_root)
    _print_banner(f"Config: {bench_cfg.label}", char="#")
    restart_each_run = config.restart_server_every_benchmark

    def start_server() -> None:
        # Same launch arguments regardless of restart policy.
        server.start(
            workflow_script=workflow_abs,
            model=config.model,
            extra_args=bench_cfg.extra_args,
            env_overrides=env_overrides,
        )

    # Each entry: (input_file, file_tag, sweep value, artifact_dir, conversation_num).
    pending_runs: List[tuple[str, str, int, Path, int]] = []
    for input_file in config.input_files:
        file_tag = input_file_tag(input_file)
        sweep_dir = output_base / file_tag / bench_cfg.label
        conversation_num = _resolve_conversation_num(config, input_file)

        for value in sorted(sweep_values):
            artifact_dir = sweep_dir / f"{sweep_mode}{value}"
            if not (artifact_dir / "profile_export_aiperf.json").exists():
                pending_runs.append(
                    (input_file, file_tag, value, artifact_dir, conversation_num)
                )
                continue
            print(
                f"  SKIP {bench_cfg.label} {sweep_mode}={value} "
                f"(results exist in {artifact_dir})",
                flush=True,
            )

    if not pending_runs:
        print(f"  All runs skipped for {bench_cfg.label}", flush=True)
        return

    if not restart_each_run:
        start_server()

    try:
        for input_file, file_tag, value, artifact_dir, conversation_num in pending_runs:
            _print_banner(
                f"[{file_tag}] Config: {bench_cfg.label}  {sweep_mode}={value}",
                char="-",
            )

            if restart_each_run:
                start_server()

            try:
                run_aiperf_single(
                    model=config.model,
                    port=config.port,
                    sweep_mode=sweep_mode,
                    sweep_value=value,
                    conversation_num=conversation_num,
                    warmup_count=config.warmup_count,
                    input_file=input_file,
                    osl=config.osl,
                    artifact_dir=artifact_dir,
                )
            finally:
                if restart_each_run:
                    server.stop()
    finally:
        if not restart_each_run:
            server.stop()
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227


def _generate_plots_for_file(
    file_output_dir: Path,
    labels: List[str],
) -> None:
    """Generate comparison plots for one input file across all configs."""
    try:
        from benchmarks.utils.plot import generate_plots

        plots_dir = file_output_dir / "plots"
        print(f"\nGenerating plots -> {plots_dir}", flush=True)
        generate_plots(
            base_output_dir=file_output_dir,
            output_dir=plots_dir,
            benchmark_names=labels,
        )
    except ImportError:
        print(
            "WARNING: benchmarks.utils.plot not importable; skipping plots.",
            flush=True,
        )
    except Exception as exc:
        print(f"WARNING: Plot generation failed: {exc}", flush=True)


def _print_summary(config: SweepConfig, output_base: Path) -> None:
    """Print where each config's results (and plots, unless skipped) live.

    Directories are listed per input file as
    ``<output_base>/<file_tag>/<config label>``, plus
    ``<output_base>/<file_tag>/plots`` when plots are enabled.
    """
    _print_banner("Sweep Complete!")
    print(f"  Results: {output_base}")
    for input_file in config.input_files:
        tag = input_file_tag(input_file)
        tag_dir = output_base / tag
        print(f"  [{tag}]:")
        for cfg in config.configs:
            print(f"    {cfg.label}: {tag_dir / cfg.label}")
        if config.skip_plots:
            continue
        print(f"    plots:  {tag_dir / 'plots'}")
    print(flush=True)