helpers.py
import functools
from io import BytesIO
from itertools import product
import os
import random
from typing import Any

import torch

from bitsandbytes.cextension import HIP_ENVIRONMENT

test_dims_rng = random.Random(42)


TRUE_FALSE = (True, False)
BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3))  # all combinations of (bool, bool, bool)
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)
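# For reference: BOOLEAN_TUPLES == [(True, True), (True, False), (False, True), (False, False)].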


@functools.cache
def get_available_devices(no_cpu=False):
    """Return the list of torch device strings that the tests can run on."""
    if "BNB_TEST_DEVICE" in os.environ:
        # If the environment variable is set, use it directly.
        device = os.environ["BNB_TEST_DEVICE"]
        return [] if no_cpu and device == "cpu" else [device]

    # CPU is included by default, but is skipped when `no_cpu` is set or under ROCm/HIP.
    devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else []

    if hasattr(torch, "accelerator"):
        # PyTorch 2.6+ - determine the accelerator using the backend-agnostic API.
        if torch.accelerator.is_available():
            devices += [str(torch.accelerator.current_accelerator())]
    else:
        if torch.cuda.is_available():
            devices += ["cuda"]

        if torch.backends.mps.is_available():
            devices += ["mps"]

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            devices += ["xpu"]

        custom_backend_name = torch._C._get_privateuse1_backend_name()
        custom_backend_module = getattr(torch, custom_backend_name, None)
        custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None)

        if custom_backend_is_available_fn and custom_backend_module.is_available():
            devices += [custom_backend_name]

    return devices
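
# Usage sketch (assumes the pytest parametrization style used by the test files,
# not something this module itself requires):
#
#   @pytest.mark.parametrize("device", get_available_devices())
#   def test_something(device):
#       ...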


def torch_save_to_buffer(obj):
    """Serialize `obj` with torch.save into an in-memory BytesIO buffer."""
    buffer = BytesIO()
    torch.save(obj, buffer)
    buffer.seek(0)
    return buffer


def torch_load_from_buffer(buffer):
    """Load an object back from an in-memory buffer written by torch.save."""
    buffer.seek(0)
    obj = torch.load(buffer, weights_only=False)
    buffer.seek(0)
    return obj
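
# Round-trip sketch for the two buffer helpers above (any picklable object works;
# a random tensor is just an illustration):
#
#   buffer = torch_save_to_buffer(torch.randn(4, 4))
#   restored = torch_load_from_buffer(buffer)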


def get_test_dims(min: int, max: int, *, n: int) -> list[int]:
    """Return `n` random dimensions between `min` and `max` (deterministic via the seeded RNG)."""
    return [test_dims_rng.randint(min, max) for _ in range(n)]


def format_with_label(label: str, value: Any) -> str:
    """Format `value` as a compact "label=value" string for use in test IDs."""
    if isinstance(value, bool):
        formatted = "T" if value else "F"
    elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value):
        formatted = "".join("T" if b else "F" for b in value)
    elif isinstance(value, torch.dtype):
        formatted = describe_dtype(value)
    else:
        formatted = str(value)
    return f"{label}={formatted}"


def id_formatter(label: str):
    """
    Return a function that formats the value given to it with the given label.
    """
    return lambda value: format_with_label(label, value)
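
# Sketch (assumes pytest): passing `ids=id_formatter("dtype")` to
# `pytest.mark.parametrize` turns a parameter value such as torch.float16 into
# the test-ID segment "dtype=fp16".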


DTYPE_NAMES = {
    torch.bfloat16: "bf16",
    torch.bool: "bool",
    torch.float16: "fp16",
    torch.float32: "fp32",
    torch.float64: "fp64",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.int8: "int8",
}


def describe_dtype(dtype: torch.dtype) -> str:
    return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]
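
# e.g. describe_dtype(torch.bfloat16) == "bf16"; dtypes missing from DTYPE_NAMES
# fall back to the suffix of str(dtype), so torch.complex64 -> "complex64".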


def is_supported_on_hpu(
    quant_type: str = "nf4", dtype: torch.dtype = torch.bfloat16, quant_storage: torch.dtype = torch.uint8
) -> bool:
    """
    Check if the given quant_type, dtype and quant_storage are supported on HPU.
    """
    if quant_type == "fp4" or dtype == torch.float16 or quant_storage not in (torch.uint8, torch.bfloat16):
        return False
    return True
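
# Sketch: is_supported_on_hpu("nf4", torch.bfloat16, torch.uint8) is True, while
# is_supported_on_hpu("fp4") is False (FP4, float16 compute, and storage dtypes
# other than uint8/bfloat16 are all rejected for HPU by this check).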