import_utils.py 2.27 KB
Newer Older
jixx's avatar
init  
jixx committed
1
2
3
4
5
import torch
from loguru import logger
import os


jixx's avatar
jixx committed
6
7
8
import importlib.util


jixx's avatar
init  
jixx committed
9
def is_ipex_available():
    """Return True when the ``intel_extension_for_pytorch`` package can be found.

    Only the import machinery is queried through ``importlib.util.find_spec``;
    the package itself is not imported by this check.
    """
    return importlib.util.find_spec("intel_extension_for_pytorch") is not None
jixx's avatar
init  
jixx committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58


def get_cuda_free_memory(device, memory_fraction):
    """Return the free CUDA memory on ``device`` that fits the allowed fraction.

    A share of ``1 - memory_fraction`` of the device's total memory is held
    back and subtracted from the currently free amount reported by
    ``torch.cuda.mem_get_info``.

    Args:
        device: Device handle forwarded to ``torch.cuda.mem_get_info`` and
            ``torch.cuda.get_device_properties``.
        memory_fraction: Fraction of the total device memory that may be used.

    Returns:
        The remaining budget, never below ``0``.
    """
    available_now, _ = torch.cuda.mem_get_info(device)
    device_capacity = torch.cuda.get_device_properties(device).total_memory
    # Portion of the device that must stay untouched regardless of what is free.
    held_back = (1 - memory_fraction) * device_capacity
    return max(0, available_now - held_back)


def get_xpu_free_memory(device, memory_fraction):
    """Return the XPU memory budget still available on ``device``.

    The budget is 90% of the device's total memory, scaled by a fraction and
    reduced by the memory already reserved on the device
    (``torch.xpu.memory_reserved``).

    NOTE: the ``memory_fraction`` argument is not used. The fraction is always
    read from the ``XPU_MEMORY_FRACTION`` environment variable, defaulting to
    ``1.0`` when it is unset.

    Args:
        device: Device object; its ``index`` attribute selects the XPU.
        memory_fraction: Accepted for signature parity with the other
            ``get_*_free_memory`` helpers; overridden by the environment.

    Returns:
        The budget as an ``int``, never below ``0``.
    """
    capacity = torch.xpu.get_device_properties(device).total_memory
    xpu_index = device.index
    env_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
    usable = capacity * 0.9 * env_fraction
    remaining = int(usable - torch.xpu.memory_reserved(xpu_index))
    return max(0, remaining)


def get_cpu_free_memory(device, memory_fraction):
    """Return this process's share of the available host memory.

    Takes 95% of the memory that ``psutil.virtual_memory()`` reports as
    available and divides it by ``WORLD_SIZE`` (imported from
    ``text_generation_server.utils.dist``; presumably the number of
    processes sharing the host -- confirm against that module).

    Args:
        device: Unused; kept for signature parity with the other
            ``get_*_free_memory`` helpers.
        memory_fraction: Unused, for the same reason.

    Returns:
        The per-process share as an ``int``.
    """
    # Imported lazily, so they are only required when this function is called.
    import psutil
    from text_generation_server.utils.dist import WORLD_SIZE

    host_available = psutil.virtual_memory().available
    return int(host_available * 0.95 / WORLD_SIZE)


def noop(*args, **kwargs):
    """Accept any arguments, do nothing, and return ``None``."""
    return None


# Backend detection, performed once at import time.
#
# The first matching branch sets SYSTEM and binds the module-level helpers
# ``empty_cache``, ``synchronize`` and ``get_free_memory`` to the matching
# backend implementation.
SYSTEM = None
if torch.version.hip is not None:
    # torch.version.hip is set: treat the install as a ROCm build, which
    # uses the torch.cuda API.
    SYSTEM = "rocm"
    empty_cache = torch.cuda.empty_cache
    synchronize = torch.cuda.synchronize
    get_free_memory = get_cuda_free_memory
elif torch.version.cuda is not None and torch.cuda.is_available():
    SYSTEM = "cuda"
    empty_cache = torch.cuda.empty_cache
    synchronize = torch.cuda.synchronize
    get_free_memory = get_cuda_free_memory
elif is_ipex_available():
    SYSTEM = "ipex"
    # Imported for its side effects only; the name itself is not used here.
    import intel_extension_for_pytorch  # noqa: F401

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        empty_cache = torch.xpu.empty_cache
        synchronize = torch.xpu.synchronize
        get_free_memory = get_xpu_free_memory
    else:
        # IPEX is installed but no XPU is usable: use the CPU helpers.
        empty_cache = noop
        synchronize = noop
        get_free_memory = get_cpu_free_memory
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    # XPU is available without the intel_extension_for_pytorch package.
    SYSTEM = "xpu"
    empty_cache = torch.xpu.empty_cache
    synchronize = torch.xpu.synchronize
    get_free_memory = get_xpu_free_memory
else:
    SYSTEM = "cpu"

    empty_cache = noop
    synchronize = noop
    get_free_memory = get_cpu_free_memory
logger.info(f"Detected system {SYSTEM}")