"""Select and re-export platform-specific operator backends.

The backend is chosen from ``AI_DEVICE`` (imported from
``lightx2v_platform.base.global_var``) and, for the ``"cuda"`` device, the
``PLATFORM`` environment variable:

* ``AI_DEVICE == "mlu"``                      -> ``.attn.cambricon_mlu`` and ``.mm.cambricon_mlu``
* ``AI_DEVICE == "cuda"`` and ``PLATFORM=dcu`` -> ``.attn.dcu`` and ``.mm.dcu``
* any other combination                       -> nothing is imported by this module

The selected submodules are pulled in with wildcard imports, so their public
names (as governed by each submodule's ``__all__``, if defined) become
attributes of this package.
"""

import os

from lightx2v_platform.base.global_var import AI_DEVICE

if AI_DEVICE == "mlu":
    # Wildcard imports are intentional: this package re-exports whatever the
    # backend submodules define. NOTE(review): importing them may also be
    # needed for registration side effects — confirm before narrowing.
    from .attn.cambricon_mlu import *  # noqa: F401,F403
    from .mm.cambricon_mlu import *  # noqa: F401,F403
elif AI_DEVICE == "cuda":
    # DCU is handled under the "cuda" device branch and is told apart only by
    # the PLATFORM environment variable being exactly "dcu".
    # NOTE(review): plain CUDA (PLATFORM unset or different) imports nothing
    # here — presumably handled elsewhere; confirm this is intended.
    if os.getenv("PLATFORM") == "dcu":
        from .attn.dcu import *  # noqa: F401,F403
        from .mm.dcu import *  # noqa: F401,F403