conftest.py 2.54 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.
"""conftest for tests/jax"""
5
import os
6
7
import jax
import pytest
8
9
from collections import defaultdict
import time
10

11
12
13

import transformer_engine.jax
from transformer_engine_jax import get_device_compute_capability
14

15

16
@pytest.fixture(autouse=True, scope="function")
17
18
19
20
21
22
23
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()
24
25
26


@pytest.fixture(autouse=True, scope="module")
27
def enable_fused_attn_after_hopper():
28
29
30
31
32
33
34
35
36
    """
    Enable fused attn for hopper+ arch.
    Fused attn kernels on pre-hopper arch are not deterministic.
    """
    if get_device_compute_capability(0) >= 90:
        os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    if "NVTE_FUSED_ATTN" in os.environ:
        del os.environ["NVTE_FUSED_ATTN"]
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87


class TestTimingPlugin:
    """
    Plugin to measure test execution time. Enable test timing by setting NVTE_JAX_TEST_TIMING=1
    in the environment.
    """

    def __init__(self):
        self.test_timings = defaultdict(list)

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtest_setup(self, item):
        item._timing_start = time.time()

    @pytest.hookimpl(trylast=True)
    def pytest_runtest_teardown(self, item, nextitem):
        if hasattr(item, "_timing_start"):
            duration = time.time() - item._timing_start

            # Extract base function name without parameters
            test_name = item.name
            if "[" in test_name:
                base_name = test_name.split("[")[0]
            else:
                base_name = test_name

            self.test_timings[base_name].append(duration)

    def pytest_sessionfinish(self, session, exitstatus):
        print("\n" + "=" * 80)
        print("TEST RUNTIME SUMMARY (grouped by function)")
        print("=" * 80)

        total_overall = 0
        for test_name, durations in sorted(self.test_timings.items()):
            total_time = sum(durations)
            count = len(durations)
            avg_time = total_time / count if count > 0 else 0
            total_overall += total_time

            print(f"{test_name:<60} | {count:3}x | {total_time:7.2f}s | avg: {avg_time:6.2f}s")

        print("=" * 80)
        print(f"{'TOTAL RUNTIME':<60} | {'':>3}  | {total_overall:7.2f}s |")
        print("=" * 80)


def pytest_configure(config):
    if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1":
        config.pluginmanager.register(TestTimingPlugin(), "test_timing")