dcu.py
import torch
import torch.distributed as dist

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


@PLATFORM_DEVICE_REGISTER("dcu")
class DcuDevice:
    """
    DCU (AMD GPU) Device implementation for LightX2V.

    DCU uses ROCm which provides CUDA-compatible APIs through HIP.
    Most PyTorch operations work transparently through the ROCm backend.
    """

    name = "dcu"

    @staticmethod
    def is_available() -> bool:
        """
        Check if DCU is available.

        DCU uses the standard CUDA API through ROCm's HIP compatibility layer.

        Returns:
            bool: True if DCU/CUDA is available
        """
        try:
            return torch.cuda.is_available()
        except Exception:
            # Covers runtime/driver initialization failures (e.g. a missing or
            # misconfigured ROCm installation).
            return False

    @staticmethod
    def get_device() -> str:
        """
        Get the device type string.

        Returns "cuda" because DCU uses CUDA-compatible APIs through ROCm.
        This allows seamless integration with existing PyTorch code.

        Returns:
            str: "cuda" for ROCm compatibility
        """
        return "cuda"

    @staticmethod
    def init_parallel_env():
        """
        Initialize distributed parallel environment for DCU.

        Uses RCCL (ROCm Collective Communications Library) which is
        compatible with NCCL APIs for multi-GPU communication.
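
        Typically launched with torchrun, e.g. (illustrative; placeholder
        names, not part of this module):
            torchrun --nproc_per_node=<num_devices> your_script.py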
        """
        # RCCL is compatible with NCCL backend
        dist.init_process_group(backend="nccl")
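        # Bind this process to one device; assumes a single-node launch where
        # the global rank equals the local device index (multi-node setups
        # would need the local rank instead).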
        torch.cuda.set_device(dist.get_rank())
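

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): exercises the static helpers
    # defined above directly, without going through the platform registry.
    if DcuDevice.is_available():
        x = torch.ones(4, device=DcuDevice.get_device())
        print(f"{DcuDevice.name}: created {tuple(x.shape)} tensor on {x.device}")
    else:
        print("DCU/ROCm device not available")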