conftest.py 2.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# SPDX-License-Identifier: MIT

import json
import threading

import pytest

# Serializes appends to the --benchmark-output JSONL file (see the open/write
# in benchmark_record) so concurrent threads in this process never interleave
# partial lines. threading.Lock is process-local: it does NOT coordinate
# separate worker processes writing to the same file.
_JSONL_WRITE_LOCK = threading.Lock()


def _make_param_key(params: dict) -> str:
    # Keep deterministic order like TileKernels helper.
    return ",".join(f"{k}={params[k]}" for k in sorted(params.keys()))


def pytest_addoption(parser):
    """Register the command-line options that control benchmark tests.

    ``--run-benchmark`` is a boolean opt-in; tests marked ``benchmark`` are
    skipped unless it is passed. ``--benchmark-output`` optionally names a
    JSONL file that benchmark results are appended to.
    """
    run_help = "Run benchmark tests (skipped by default)"
    output_help = "Path to write benchmark results as JSONL (one JSON object per line)"

    # Flag: False when absent, True when given on the command line.
    parser.addoption("--run-benchmark", default=False, action="store_true", help=run_help)
    # Optional path; None means no results file is written.
    parser.addoption("--benchmark-output", help=output_help, default=None)


def pytest_configure(config):
    """Register the ``benchmark`` marker in the ``markers`` ini setting."""
    marker_spec = "benchmark: mark test as benchmark (skip by default)"
    config.addinivalue_line("markers", marker_spec)


def pytest_collection_modifyitems(config, items):
    """Skip tests carrying the ``benchmark`` marker unless ``--run-benchmark`` is set.

    Selection is done with ``get_closest_marker``, which looks at actual
    markers only. The previous ``"benchmark" in item.keywords`` check also
    matched node names (the test itself and its class/module/directory), so
    any test living under something named ``benchmark`` was skipped even
    without the marker.
    """
    if config.getoption("--run-benchmark"):
        return

    benchmark_items = [
        item for item in items if item.get_closest_marker("benchmark") is not None
    ]
    if not benchmark_items:
        return

    # One shared skip marker instance is enough for every benchmark item.
    skip_bench = pytest.mark.skip(reason="need --run-benchmark to run")
    for item in benchmark_items:
        item.add_marker(skip_bench)


@pytest.fixture
def benchmark_record(request):
    """Provide a callable that reports one benchmark measurement.

    Each call prints a one-line summary to stdout. When
    ``--benchmark-output`` was given, each call also appends the measurement
    to that file as a single JSON object on its own line.

    The returned callable takes keyword-only arguments:

    * ``kernel`` / ``operation`` -- identify the measurement; printed as
      ``kernel/operation`` plus ``[k=v,...]`` when ``params`` is non-empty.
    * ``params`` -- mapping of benchmark parameters (may be empty or None);
      keys are emitted in sorted order.
    * ``time_us`` -- measured time (printed with 1 decimal and a ``us``
      suffix, stored rounded to 2 decimals).
    * ``bandwidth_gbs`` -- optional; printed with 2 decimals, stored rounded
      to 4.
    * ``extras`` -- optional mapping of extra metrics; float values are
      printed with 2 decimals and stored rounded to 4, others as-is.
    """
    output_path = request.config.getoption("--benchmark-output")

    def _record(*, kernel, operation, params, time_us, bandwidth_gbs=None, extras=None):
        if params:
            key = f"{kernel}/{operation}[{_make_param_key(params)}]"
        else:
            key = f"{kernel}/{operation}"

        parts = [f"  BENCH {key}: {time_us:.1f} us"]
        if bandwidth_gbs is not None:
            parts.append(f", bandwidth_gbs={bandwidth_gbs:.2f}")
        if extras:
            for ek, ev in extras.items():
                if isinstance(ev, float):
                    parts.append(f", {ek}={ev:.2f}")
                else:
                    parts.append(f", {ek}={ev}")
        print("".join(parts))

        if not output_path:
            return

        record = {
            "kernel": kernel,
            "operation": operation,
            # Sort by key only (same order as _make_param_key); empty/None
            # params are stored unchanged.
            "params": {k: params[k] for k in sorted(params)} if params else params,
            "time_us": round(time_us, 2),
        }
        if bandwidth_gbs is not None:
            record["bandwidth_gbs"] = round(bandwidth_gbs, 4)
        if extras:
            record["extras"] = {
                k: round(v, 4) if isinstance(v, float) else v for k, v in extras.items()
            }
        line = json.dumps(record, ensure_ascii=False)
        # ensure_ascii=False may emit non-ASCII characters, so write UTF-8
        # explicitly rather than relying on the platform locale encoding.
        with _JSONL_WRITE_LOCK:
            with open(output_path, "a", encoding="utf-8") as f:
                f.write(line + "\n")

    return _record


@pytest.fixture
def benchmark_timer():
    """Provide a callable that times ``fn`` with tilelang's ``do_bench``.

    The callable uses ``backend="cupti"``, ``warmup=0`` and ``rep=30``
    unless overridden by keyword arguments. It converts the ``do_bench``
    result from milliseconds to microseconds.
    """
    # Imported here rather than at module top, so loading this conftest does
    # not itself require tilelang.
    from tilelang.profiler.bench import do_bench

    defaults = {"backend": "cupti", "warmup": 0, "rep": 30}

    def _timer(fn, **overrides):
        # Caller-supplied keywords win over the defaults; a fresh dict each call.
        options = {**defaults, **overrides}
        elapsed_ms = do_bench(fn, **options)
        return elapsed_ms * 1e3  # ms -> us

    return _timer