__init__.py 331 Bytes
Newer Older
xuwx1's avatar
xuwx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
import os

from lightx2v_platform.base.global_var import AI_DEVICE

# Select the device-specific "attn" (presumably attention) and "mm" backend
# modules based on AI_DEVICE from lightx2v_platform.base.global_var.
# NOTE(review): the wildcard imports presumably exist for their import-time
# side effects and re-exported names — confirm before replacing them.
# Any AI_DEVICE not matched below, and "cuda" without PLATFORM=dcu, imports
# nothing from this package.
if AI_DEVICE == "mlu":
    # Cambricon MLU backends; attn is imported before mm so mm's names win on clash.
    from .attn.cambricon_mlu import *
    from .mm.cambricon_mlu import *
elif AI_DEVICE == "cuda" and os.getenv("PLATFORM") == "dcu":
    # DCU is reported as a "cuda" device and told apart only by the PLATFORM
    # environment variable (read only when AI_DEVICE is "cuda").
    from .attn.dcu import *
    from .mm.dcu import *