base.py 1.4 KB
Newer Older
xuwx1's avatar
xuwx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os

from loguru import logger

from lightx2v_platform.base import global_var
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


def init_ai_device(platform="cuda"):
    platform_device = PLATFORM_DEVICE_REGISTER.get(platform, None)
    if platform_device is None:
        available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
        raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {available_platforms}")
    global_var.AI_DEVICE = platform_device.get_device()
    logger.info(f"Initialized AI_DEVICE: {global_var.AI_DEVICE}")
    return global_var.AI_DEVICE


def check_ai_device(platform="cuda"):
    platform_device = PLATFORM_DEVICE_REGISTER.get(platform, None)
    if platform_device is None:
        available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
        raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {available_platforms}")
    is_available = platform_device.is_available()
    if not is_available:
        skip_platform_check = os.getenv("SKIP_PLATFORM_CHECK", "False") in ["1", "True"]
        error_msg = f"AI device for platform '{platform}' is not available. Please check your runtime environment."
        if skip_platform_check:
            logger.warning(error_msg)
            return True
        raise RuntimeError(error_msg)
    logger.info(f"AI device for platform '{platform}' is available.")
    return True