# set_ai_device.py
import os

from lightx2v_platform import *


def set_ai_device(platform: str = "cuda") -> None:
    """Initialise the AI device for ``platform`` and validate the result.

    Calls ``init_ai_device(platform)``, then reads ``AI_DEVICE`` from
    ``lightx2v_platform.base.global_var`` and passes it to
    ``check_ai_device``.

    Args:
        platform: Platform name forwarded to ``init_ai_device``. Defaults to
            ``"cuda"``, which was previously hard-coded, so existing
            zero-argument calls behave exactly as before.
    """
    init_ai_device(platform)
    # Imported only after init_ai_device() has run: a ``from ... import``
    # binds the value present at import time. Presumably init_ai_device()
    # is what assigns AI_DEVICE, so importing earlier could capture a stale
    # value — confirm against lightx2v_platform.base.global_var.
    from lightx2v_platform.base.global_var import AI_DEVICE

    check_ai_device(AI_DEVICE)


# Initialise and check the AI device as an import-time side effect of
# this module (runs with the default platform).
set_ai_device()
# Deliberately imported only after set_ai_device() has run, hence the E402
# ("module level import not at top of file") suppression. Presumably the ops
# modules depend on the device selected above — confirm before reordering.
from lightx2v_platform.ops import *  # noqa: E402