# SPDX-License-Identifier: MIT
"""Pytest hooks and fixtures for opt-in benchmark tests.

Adds the ``--run-benchmark`` / ``--benchmark-output`` command-line options,
registers the ``benchmark`` marker, and provides the ``benchmark_record`` and
``benchmark_timer`` fixtures.
"""

import json
import threading

import pytest

# Guards appends to the JSONL output file. This is a ``threading.Lock``, so it
# only serializes writers that live in the same process.
_JSONL_WRITE_LOCK = threading.Lock()


def _make_param_key(params: dict) -> str:
    """Return ``params`` rendered as comma-joined ``key=value`` pairs."""
    # Keep deterministic order like TileKernels helper.
    return ",".join(f"{k}={params[k]}" for k in sorted(params.keys()))


def _format_bench_line(kernel, operation, params, time_us, bandwidth_gbs, extras) -> str:
    """Build the one-line console summary for a single measurement."""
    key = f"{kernel}/{operation}"
    if params:
        key = f"{key}[{_make_param_key(params)}]"

    parts = [f" BENCH {key}: {time_us:.1f} us"]
    if bandwidth_gbs is not None:
        parts.append(f", bandwidth_gbs={bandwidth_gbs:.2f}")
    if extras:
        # Floats get two decimals; every other value uses its default formatting.
        parts.extend(
            f", {name}={value:.2f}" if isinstance(value, float) else f", {name}={value}"
            for name, value in extras.items()
        )
    return "".join(parts)


def _build_jsonl_record(kernel, operation, params, time_us, bandwidth_gbs, extras) -> dict:
    """Build the JSON-serializable dict written to the JSONL output file."""
    record = {
        "kernel": kernel,
        "operation": operation,
        # Sort by key for stable output; falsy ``params`` (None / {}) pass through as-is.
        "params": dict(sorted(params.items())) if params else params,
        "time_us": round(time_us, 2),
    }
    if bandwidth_gbs is not None:
        record["bandwidth_gbs"] = round(bandwidth_gbs, 4)
    if extras:
        record["extras"] = {
            name: round(value, 4) if isinstance(value, float) else value
            for name, value in extras.items()
        }
    return record


def pytest_addoption(parser):
    """Register the ``--run-benchmark`` and ``--benchmark-output`` options."""
    parser.addoption(
        "--run-benchmark",
        action="store_true",
        default=False,
        help="Run benchmark tests (skipped by default)",
    )
    parser.addoption(
        "--benchmark-output",
        default=None,
        help="Path to write benchmark results as JSONL (one JSON object per line)",
    )


def pytest_configure(config):
    """Register the ``benchmark`` marker."""
    config.addinivalue_line("markers", "benchmark: mark test as benchmark (skip by default)")


def pytest_collection_modifyitems(config, items):
    """Mark ``benchmark`` tests as skipped unless ``--run-benchmark`` was given."""
    if config.getoption("--run-benchmark"):
        return
    skip_bench = pytest.mark.skip(reason="need --run-benchmark to run")
    for item in items:
        if "benchmark" in item.keywords:
            item.add_marker(skip_bench)


@pytest.fixture
def benchmark_record(request):
    """Return a callable that reports one benchmark measurement.

    The returned ``_record`` takes keyword-only arguments ``kernel``,
    ``operation``, ``params``, ``time_us`` and optional ``bandwidth_gbs`` /
    ``extras``. It always prints a one-line summary. When the
    ``--benchmark-output`` option is set, it also appends the measurement to
    that path as a single JSON line.
    """
    output_path = request.config.getoption("--benchmark-output")

    def _record(*, kernel, operation, params, time_us, bandwidth_gbs=None, extras=None):
        print(_format_bench_line(kernel, operation, params, time_us, bandwidth_gbs, extras))
        if not output_path:
            return

        record = _build_jsonl_record(kernel, operation, params, time_us, bandwidth_gbs, extras)
        line = json.dumps(record, ensure_ascii=False)
        with _JSONL_WRITE_LOCK:
            # ensure_ascii=False may emit non-ASCII characters, so the encoding
            # is pinned to UTF-8 instead of relying on the locale default.
            with open(output_path, "a", encoding="utf-8") as f:
                f.write(line + "\n")

    return _record


@pytest.fixture
def benchmark_timer():
    """Return a callable that times ``fn`` with ``do_bench``.

    The returned ``_timer(fn, **overrides)`` calls
    ``tilelang.profiler.bench.do_bench`` with ``backend="cupti"``,
    ``warmup=0`` and ``rep=30`` unless overridden, and multiplies the result
    by 1e3.
    """
    # Imported lazily so that loading this module does not require tilelang.
    from tilelang.profiler.bench import do_bench

    def _timer(fn, **overrides):
        kwargs = dict(backend="cupti", warmup=0, rep=30)
        kwargs.update(overrides)
        return do_bench(fn, **kwargs) * 1e3  # ms -> us

    return _timer