common.py 1.14 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache

7
import transformer_engine
8
from transformer_engine_jax import get_device_compute_capability
9
from transformer_engine.common import recipe
10
11
12
13
14
15
16


@lru_cache
def is_bf16_supported():
    """Check whether the queried device is new enough for BF16.

    Calls ``get_device_compute_capability(0)`` and compares the value it
    reports against 80. The function is wrapped in ``lru_cache``, so the
    query is issued on the first call only and the cached answer is reused
    afterwards.

    Returns:
        The result of ``capability >= 80``.
    """
    # NOTE(review): 80 presumably encodes compute capability 8.0 and the
    # argument 0 presumably selects the first GPU -- confirm against
    # transformer_engine_jax.
    min_compute_capability = 80
    return get_device_compute_capability(0) >= min_compute_capability
17
18
19
20
21
22
23


@lru_cache
def is_fp8_supported():
    """Check whether the queried device is new enough for FP8.

    Calls ``get_device_compute_capability(0)`` and compares the value it
    reports against 90. The function is wrapped in ``lru_cache``, so the
    query is issued on the first call only and the cached answer is reused
    afterwards.

    Returns:
        The result of ``capability >= 90``.
    """
    # NOTE(review): 90 presumably encodes compute capability 9.0 and the
    # argument 0 presumably selects the first GPU -- confirm against
    # transformer_engine_jax.
    min_compute_capability = 90
    return get_device_compute_capability(0) >= min_compute_capability
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41


@lru_cache
def is_mxfp8_supported():
    """Return whether MXFP8 is supported by the hardware.

    Calls ``get_device_compute_capability(0)`` and compares the value it
    reports against 100. The function is wrapped in ``lru_cache``, so the
    query is issued on the first call only and the cached answer is reused
    afterwards.

    Returns:
        The result of ``capability >= 100``.
    """
    # NOTE(review): 100 presumably encodes compute capability 10.0 and the
    # argument 0 presumably selects the first GPU -- confirm against
    # transformer_engine_jax.
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 100


def get_fp8_recipe_from_name_string(name: str):
    """Build the FP8 recipe that corresponds to a recipe name.

    Args:
        name: Recipe name. ``"DelayedScaling"`` and ``"MXFP8BlockScaling"``
            are the only recognized values; the comparison is exact.

    Returns:
        The result of calling ``recipe.DelayedScaling()`` or
        ``recipe.MXFP8BlockScaling()`` with no arguments, depending on
        ``name``.

    Raises:
        ValueError: If ``name`` is not one of the recognized values.
    """
    if name == "DelayedScaling":
        return recipe.DelayedScaling()
    if name == "MXFP8BlockScaling":
        return recipe.MXFP8BlockScaling()
    # Anything else is rejected rather than silently falling back to a default.
    raise ValueError(f"Invalid fp8_recipe, got {name}")