# rocm.py - ROCm implementation of the Platform interface.
from functools import lru_cache
from typing import Tuple

import torch

from .interface import Platform, PlatformEnum


class RocmPlatform(Platform):
    """Platform implementation for AMD ROCm.

    Device queries are delegated to ``torch.cuda``; PyTorch's ROCm (HIP)
    builds reuse the ``torch.cuda`` API, so no separate HIP calls are needed.
    Query results are memoized per argument tuple with ``lru_cache``.
    """

    _enum = PlatformEnum.ROCM

    # Decorator order matters: ``lru_cache`` wraps the plain function first,
    # and ``staticmethod`` is applied outermost so the cached callable is
    # exposed without a bound ``self``. maxsize=8 bounds the cache to a
    # handful of device ids (values presumably fixed for the process
    # lifetime - confirm if devices can change at runtime).
    @staticmethod
    @lru_cache(maxsize=8)
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Return the ``(major, minor)`` capability tuple of a device.

        Args:
            device_id: Index of the device to query (defaults to 0).

        Returns:
            The tuple reported by ``torch.cuda.get_device_capability``.
        """
        return torch.cuda.get_device_capability(device_id)

    @staticmethod
    @lru_cache(maxsize=8)
    def get_device_name(device_id: int = 0) -> str:
        """Return the name of a device as reported by ``torch.cuda``.

        Args:
            device_id: Index of the device to query (defaults to 0).

        Returns:
            The string returned by ``torch.cuda.get_device_name``.
        """
        return torch.cuda.get_device_name(device_id)