"benchmark/json_decode_regex/README.md" did not exist on "f652494df16ef9fa0fac998ddf63961aee0849d4"
base.py 1.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from loguru import logger

from lightx2v_platform.base import global_var
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


def init_ai_device(platform="cuda"):
    """Resolve the device for ``platform`` and publish it globally.

    The registry entry for ``platform`` is asked for its device via
    ``get_device()``; the result is stored in ``global_var.AI_DEVICE``
    and logged.

    Args:
        platform: Key looked up in ``PLATFORM_DEVICE_REGISTER``.

    Returns:
        The value of ``global_var.AI_DEVICE`` after it has been set.

    Raises:
        RuntimeError: If the registry has no (non-``None``) entry for
            ``platform``; the message lists the registered keys.
    """
    entry = PLATFORM_DEVICE_REGISTER.get(platform, None)

    if entry is not None:
        # Happy path: store the device on the shared module so other code can read it.
        global_var.AI_DEVICE = entry.get_device()
        logger.info(f"Initialized AI_DEVICE: {global_var.AI_DEVICE}")
        return global_var.AI_DEVICE

    # Unknown platform: report what is registered to help the caller pick a valid key.
    available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
    raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {available_platforms}")


def check_ai_device(platform="cuda"):
    """Verify that the device registered for ``platform`` is usable.

    Args:
        platform: Key looked up in ``PLATFORM_DEVICE_REGISTER``.

    Returns:
        ``True`` when the registry entry's ``is_available()`` reports a
        truthy result. The function never returns ``False``; failures raise.

    Raises:
        RuntimeError: If ``platform`` has no (non-``None``) registry entry,
            or if the entry reports that it is not available.
    """
    entry = PLATFORM_DEVICE_REGISTER.get(platform, None)
    if entry is None:
        # Unknown platform: include the registered keys in the error.
        available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
        raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {available_platforms}")

    if entry.is_available():
        logger.info(f"AI device for platform '{platform}' is available.")
        return True

    # Registered, but the entry says the device cannot be used right now.
    raise RuntimeError(f"AI device for platform '{platform}' is not available. Please check your runtime environment.")