device_utils.py 945 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Device utility functions for JAX quantization.

This module provides utility functions for checking device capabilities and compatibility
for quantization operations in JAX.
"""

import functools

import transformer_engine_jax

# Names exported by ``from <this module> import *``; both are defined below.
__all__ = [
    "get_device_compute_capability",
    "is_fp8_gemm_with_all_layouts_supported",
]


@functools.lru_cache(maxsize=None)
def get_device_compute_capability(gpu_id: int = 0) -> int:
    """Look up the compute capability of a device.

    The query is delegated to
    ``transformer_engine_jax.get_device_compute_capability``. Results are
    memoized with an unbounded ``lru_cache``, so repeated calls with the same
    argument do not hit the extension again.

    Args:
        gpu_id: Index of the device to query. Defaults to 0.

    Returns:
        The compute capability value reported by the extension for ``gpu_id``.
    """
    capability = transformer_engine_jax.get_device_compute_capability(gpu_id)
    return capability


@functools.lru_cache(maxsize=None)
def is_fp8_gemm_with_all_layouts_supported(gpu_id: int = 0) -> bool:
    """Check whether a device's compute capability lies in ``[100, 120)``.

    The function name indicates that devices in this range support FP8 GEMM
    with all layouts. The range presumably corresponds to (a subset of) the
    Blackwell architecture -- NOTE(review): confirm against the CUDA compute
    capability table, since values of 120 and above are excluded here.

    Results are memoized with an unbounded ``lru_cache``.

    Args:
        gpu_id: Index of the device to check. Defaults to 0, which preserves
            the behaviour of calling this function with no arguments.

    Returns:
        True when ``100 <= compute capability < 120`` for the given device,
        False otherwise.
    """
    # Forward the device index so callers can query devices other than 0,
    # mirroring the signature of get_device_compute_capability.
    compute_capability = get_device_compute_capability(gpu_id)
    return 100 <= compute_capability < 120