test_main.py 3.94 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
import sys
from pathlib import Path
from unittest.mock import patch

import pytest

# Skip this whole module when Pillow cannot be imported.
pytest.importorskip("PIL", reason="Pillow required for image generation benchmarks")

# `main` is imported by bare module name below, so the directory holding the
# JSONL generator sources is placed at the front of sys.path. Per the original
# note, the generators also import each other this way (not verified here).
# parents[4] is four levels above this file's directory -- presumably the
# repository root; adjust the index if this test file is moved.
JSONL_DIR = Path(__file__).resolve().parents[4] / "benchmarks" / "multimodal" / "jsonl"
sys.path.insert(0, str(JSONL_DIR))

# Must follow the sys.path tweak above, hence the E402 suppression.
from main import main  # noqa: E402

# Markers applied to every test in this module.
pytestmark = [pytest.mark.unit, pytest.mark.pre_merge, pytest.mark.gpu_0]


def _run_main(tmp_path: Path, argv: list[str]) -> list[dict]:
    """Run ``main()`` under a patched ``sys.argv`` and return the parsed JSONL rows.

    Args:
        tmp_path: Directory expected to contain exactly one ``*.jsonl`` file
            once ``main()`` returns (the tests point ``-o`` into it).
        argv: Command-line arguments without the program name; ``"main.py"``
            is prepended as ``sys.argv[0]``.

    Returns:
        One decoded JSON object per non-blank line of the output file, in
        file order.

    Raises:
        AssertionError: If ``tmp_path`` does not hold exactly one ``*.jsonl``
            file after the run.
    """
    # patch() restores the real sys.argv on exit, even if main() raises.
    with patch("sys.argv", ["main.py", *argv]):
        main()
    jsonl_files = list(tmp_path.glob("*.jsonl"))
    assert len(jsonl_files) == 1, f"Expected 1 JSONL file, got {jsonl_files}"
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(jsonl_files[0], encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


class TestSingleTurnDefault:
    """Without a subcommand, the CLI falls back to single-turn generation."""

    def test_default_produces_independent_requests(self, tmp_path: Path) -> None:
        """Four requests come back, each with text, two images and no session id."""
        image_dir = tmp_path / "imgs"
        output_file = tmp_path / "out.jsonl"

        # No subcommand token: the argv starts directly with options.
        cli_args = ["-n", "4"]
        cli_args += ["--images-per-request", "2"]
        cli_args += ["--image-size", "32", "32"]
        cli_args += ["--image-dir", str(image_dir)]
        cli_args += ["--seed", "1"]
        cli_args += ["-o", str(output_file)]

        rows = _run_main(tmp_path, cli_args)

        assert len(rows) == 4
        for row in rows:
            assert "text" in row
            assert len(row["images"]) == 2
            # Single-turn rows are not tied to any session.
            assert "session_id" not in row


class TestSlidingWindow:
    """The sliding-window subcommand emits per-user sessions with overlapping images."""

    def test_output_structure(self, tmp_path: Path) -> None:
        """Two users with three turns each give six interleaved rows of five images."""
        cli_args = ["sliding-window"]
        cli_args += ["--num-users", "2"]
        cli_args += ["--turns-per-user", "3"]
        cli_args += ["--window-size", "5"]
        cli_args += ["--image-size", "32", "32"]
        cli_args += ["--image-dir", str(tmp_path / "imgs")]
        cli_args += ["--seed", "42"]
        cli_args += ["-o", str(tmp_path / "sw.jsonl")]

        rows = _run_main(tmp_path, cli_args)

        # 2 users x 3 turns
        assert len(rows) == 6

        # Rows alternate between the two users: user_0, user_1, user_0, ...
        expected_sessions = ["user_0", "user_1"] * 3
        observed_sessions = [row["session_id"] for row in rows]
        assert observed_sessions == expected_sessions

        # Every row carries a full window of images.
        for row in rows:
            assert len(row["images"]) == 5

    def test_image_overlap(self, tmp_path: Path) -> None:
        """Each turn's window is the previous one shifted by a single image."""
        cli_args = ["sliding-window"]
        cli_args += ["--num-users", "1"]
        cli_args += ["--turns-per-user", "3"]
        cli_args += ["--window-size", "4"]
        cli_args += ["--image-size", "32", "32"]
        cli_args += ["--image-dir", str(tmp_path / "imgs")]
        cli_args += ["--seed", "7"]
        cli_args += ["-o", str(tmp_path / "overlap.jsonl")]

        rows = _run_main(tmp_path, cli_args)

        assert len(rows) == 3

        # Adjacent turns must share window_size - 1 images: dropping the oldest
        # image of one window equals dropping the newest image of the next.
        windows = [row["images"] for row in rows]
        for turn, (earlier, later) in enumerate(zip(windows, windows[1:])):
            overlap_message = f"Turn {turn} and {turn + 1} should share 3/4 images"
            assert earlier[1:] == later[:-1], overlap_message