import functools
from io import BytesIO
from itertools import product
import os
import random
from typing import Any

import torch

# Module-level RNG with a fixed seed so the dimensions produced by
# get_test_dims() are reproducible across test runs.
test_dims_rng = random.Random(42)


# Exhaustive boolean parameter grids for test parametrization.
TRUE_FALSE = (True, False)
BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3))  # all combinations of (bool, bool, bool)
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)


@functools.cache
def get_available_devices():
    """Enumerate the torch device types usable for tests.

    The result is cached for the life of the process. Setting the
    ``BNB_TEST_DEVICE`` environment variable bypasses detection entirely and
    restricts testing to that single device string.
    """
    # An explicit override takes precedence over any autodetection.
    override = os.environ.get("BNB_TEST_DEVICE")
    if override is not None:
        return [override]

    devices = ["cpu"]

    accelerator_api = getattr(torch, "accelerator", None)
    if accelerator_api is not None:
        # PyTorch 2.6+: use the backend-agnostic accelerator API, which
        # reports at most one active accelerator.
        if accelerator_api.is_available():
            devices.append(str(accelerator_api.current_accelerator()))
        return devices

    # Older PyTorch: probe each known backend individually.
    if torch.cuda.is_available():
        devices.append("cuda")
    if torch.backends.mps.is_available():
        devices.append("mps")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        devices.append("xpu")

    # Out-of-tree backends register under the "privateuse1" mechanism; the
    # torch.<name> module may be absent or lack an is_available() probe.
    backend_name = torch._C._get_privateuse1_backend_name()
    backend_module = getattr(torch, backend_name, None)
    backend_probe = getattr(backend_module, "is_available", None)
    if backend_probe and backend_probe():
        devices.append(backend_name)

    return devices


def torch_save_to_buffer(obj):
    """Serialize ``obj`` with ``torch.save`` into an in-memory buffer.

    Returns a ``BytesIO`` rewound to position 0, ready to be read back.
    """
    out = BytesIO()
    torch.save(obj, out)
    out.seek(0)
    return out


def torch_load_from_buffer(buffer):
    """Deserialize an object previously written by ``torch_save_to_buffer``.

    The buffer is rewound before reading and again afterwards, so the same
    buffer can be loaded repeatedly. ``weights_only=False`` allows arbitrary
    pickled objects, which is acceptable for trusted test fixtures only.
    """
    buffer.seek(0)
    loaded = torch.load(buffer, weights_only=False)
    buffer.seek(0)
    return loaded


def get_test_dims(min: int, max: int, *, n: int) -> list[int]:
    """Draw ``n`` random dimensions in ``[min, max]`` (both inclusive).

    Uses the module-level seeded RNG so the generated sizes are stable
    between test runs. (The parameter names shadow builtins but are kept
    unchanged for backward compatibility with keyword callers.)
    """
    draw = test_dims_rng.randint
    return [draw(min, max) for _ in range(n)]


def format_with_label(label: str, value: Any) -> str:
    """Render ``value`` as a compact ``label=value`` string for test IDs.

    Booleans become ``"T"``/``"F"``, lists/tuples of booleans become strings
    like ``"TFT"``, torch dtypes use their short names, and everything else
    falls back to ``str``.
    """
    # The bool branch must come first: bool is a subclass of int, and a lone
    # bool must not be caught by the generic str() fallback.
    if isinstance(value, bool):
        text = "T" if value else "F"
    elif isinstance(value, (list, tuple)) and all(isinstance(item, bool) for item in value):
        text = "".join(["T" if item else "F" for item in value])
    elif isinstance(value, torch.dtype):
        text = describe_dtype(value)
    else:
        text = str(value)
    return f"{label}={text}"


def id_formatter(label: str):
    """
    Return a function that formats the value given to it with the given label.
    """
    # A named closure (rather than a lambda) keeps the deferred lookup of
    # format_with_label while being easier to read in tracebacks.
    def _format(value):
        return format_with_label(label, value)

    return _format


# Short, human-friendly names for common dtypes, used in test IDs.
DTYPE_NAMES = {
    torch.bfloat16: "bf16",
    torch.bool: "bool",
    torch.float16: "fp16",
    torch.float32: "fp32",
    torch.float64: "fp64",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.int8: "int8",
}


def describe_dtype(dtype: torch.dtype) -> str:
    """Return the short name for ``dtype``, e.g. ``"fp16"`` for ``torch.float16``.

    Dtypes without a registered short name fall back to the suffix of their
    ``str()`` form (``torch.uint8`` -> ``"uint8"``).
    """
    name = DTYPE_NAMES.get(dtype)
    if not name:
        name = str(dtype).rpartition(".")[2]
    return name


def is_supported_on_hpu(
    quant_type: str = "nf4", dtype: torch.dtype = torch.bfloat16, quant_storage: torch.dtype = torch.uint8
) -> bool:
    """
    Check if the given quant_type, dtype and quant_storage are supported on HPU.
    """
    # Supported iff: quant type is not fp4, compute dtype is not fp16, and
    # the storage dtype is one of the two HPU-compatible options.
    storage_ok = quant_storage in (torch.uint8, torch.bfloat16)
    return quant_type != "fp4" and dtype != torch.float16 and storage_ok