conftest.py 1.02 KB
Newer Older
1
2
3
4
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""conftest for tests/jax"""
5
import os
6
7
8
import jax
import pytest

9
10
from transformer_engine.transformer_engine_jax import get_device_compute_capability

11

12
@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
    """
    Delete every JAX array that is still alive once a test function is done.

    The fixture is autouse and function-scoped, so it wraps each test without
    being requested explicitly. Nothing happens before the test body; all of
    the work is teardown that runs after the ``yield``.
    """
    yield
    # Teardown: ask JAX for the arrays it still tracks and delete each one so
    # that leftovers from this test are not carried into the next one.
    leftover_arrays = jax.live_arrays()
    for leftover in leftover_arrays:
        leftover.delete()
20
21
22


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn_after_hopper():
    """
    Enable fused attn for hopper+ arch.
    Fused attn kernels on pre-hopper arch are not deterministic.

    When the compute capability reported for device 0 is at least 90, the
    fixture sets ``NVTE_FUSED_ATTN=1`` and
    ``NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`` for the duration of the test
    module. On teardown both variables are put back exactly as they were
    before the module started: a value that existed beforehand is restored,
    and a variable that was absent is removed again. This keeps settings
    exported by the user (or by an outer harness) from being silently dropped
    after the first module has run.
    """
    managed_vars = ("NVTE_FUSED_ATTN", "NVTE_ALLOW_NONDETERMINISTIC_ALGO")
    # Snapshot first; ``None`` marks a variable that was not set at all.
    saved_values = {name: os.environ.get(name) for name in managed_vars}

    # NOTE(review): the argument 0 is presumably a device index -- confirm
    # against get_device_compute_capability's definition.
    if get_device_compute_capability(0) >= 90:
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

    yield

    # Teardown: undo our changes (and any made by tests in the module) by
    # returning each managed variable to its pre-module state.
    for name, previous in saved_values.items():
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous