dcu.py
import torch
import torch.distributed as dist

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


@PLATFORM_DEVICE_REGISTER("dcu")
class DcuDevice:
    """
    DCU (AMD GPU) Device implementation for LightX2V.

    DCU uses ROCm, which provides CUDA-compatible APIs through HIP.
    Most PyTorch operations work transparently through the ROCm backend.
    """

    name = "dcu"

    @staticmethod
    def is_available() -> bool:
        """
        Check if DCU is available.

        DCU uses the standard CUDA API through ROCm's HIP compatibility layer.

        Returns:
            bool: True if DCU/CUDA is available
        """
        try:
            return torch.cuda.is_available()
        except Exception:
            # torch is imported at module scope, so guard against runtime
            # errors from a misconfigured ROCm installation.
            return False

    @staticmethod
    def get_device() -> str:
        """
        Get the device type string.

        Returns "cuda" because DCU uses CUDA-compatible APIs through ROCm.
        This allows seamless integration with existing PyTorch code.

        Returns:
            str: "cuda" for ROCm compatibility
        """
        return "cuda"

    @staticmethod
    def init_parallel_env():
        """
        Initialize distributed parallel environment for DCU.

        Uses RCCL (ROCm Collective Communications Library) which is
        compatible with NCCL APIs for multi-GPU communication.
        """
        # RCCL is compatible with NCCL backend
        dist.init_process_group(backend="nccl")
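        # Bind each process to the device matching its global rank; this
        # assumes one process per GPU on a single node.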
        torch.cuda.set_device(dist.get_rank())
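

# Minimal usage sketch, assuming the process is launched with `torchrun` so the
# usual distributed environment variables (RANK, WORLD_SIZE, MASTER_ADDR, ...)
# are set, and that at least one DCU is visible to the ROCm build of PyTorch.
if __name__ == "__main__":
    if DcuDevice.is_available():
        DcuDevice.init_parallel_env()       # RCCL through the "nccl" backend
        device = DcuDevice.get_device()     # "cuda" under ROCm/HIP
        x = torch.ones(4, device=device)
        dist.all_reduce(x)                  # sums the tensor across all ranks
        if dist.get_rank() == 0:
            print(f"all_reduce result on {device}: {x}")
        dist.destroy_process_group()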