cambricon_mlu.py 594 Bytes
Newer Older
xuwx1's avatar
xuwx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.distributed as dist

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


@PLATFORM_DEVICE_REGISTER("mlu")
class MluDevice:
    """Platform device backend for Cambricon MLU accelerators.

    Registered under the key ``"mlu"`` via ``PLATFORM_DEVICE_REGISTER``.
    MLU support comes from the optional ``torch_mlu`` extension, which is
    imported lazily inside the methods so this module can be imported on
    hosts where ``torch_mlu`` is not installed.
    """

    # Platform identifier; matches the registry key used above.
    name = "mlu"

    @staticmethod
    def is_available() -> bool:
        """Return whether MLU devices can be used in this process.

        Returns:
            ``False`` if ``torch_mlu`` cannot be imported, otherwise the
            result of ``torch_mlu.mlu.is_available()``.
        """
        try:
            import torch_mlu

            return torch_mlu.mlu.is_available()
        except ImportError:
            # A missing extension means "no MLU here", not an error.
            return False

    @staticmethod
    def get_device() -> str:
        """Return the PyTorch device type string for this platform (``"mlu"``)."""
        return "mlu"

    @staticmethod
    def init_parallel_env() -> None:
        """Initialise the ``"cncl"`` process group and bind this process to an MLU.

        The device index is read from the ``LOCAL_RANK`` environment variable
        (the node-local rank set by launchers such as ``torchrun``). If it is
        unset, the global rank from ``dist.get_rank()`` is used, which is
        correct for single-node launches.

        Raises:
            ImportError: if ``torch_mlu`` is not installed.
        """
        import os

        # NOTE(review): assumes importing torch_mlu is what makes ``torch.mlu``
        # and the "cncl" backend available. Import it explicitly rather than
        # relying on is_available() having been called earlier in the process.
        import torch_mlu  # noqa: F401

        dist.init_process_group(backend="cncl")
        # On multi-node jobs the global rank exceeds the number of devices on
        # each node, so select the device by the node-local rank instead.
        local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
        torch.mlu.set_device(local_rank)