# cpu.py — CPU platform definition.
import torch

from .interface import Platform, PlatformEnum


class CpuPlatform(Platform):
    """Platform implementation for running on the CPU.

    Identifies itself via ``PlatformEnum.CPU`` and overrides the
    device-name and inference-context hooks of ``Platform``.
    """

    _enum = PlatformEnum.CPU

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the human-readable device name.

        Args:
            device_id: Index of the device. Ignored here: the name is
                always ``"cpu"``. It is presumably accepted to match the
                base-class signature (confirm against ``Platform``).

        Returns:
            The string ``"cpu"``.
        """
        return "cpu"

    @classmethod
    def inference_mode(cls):
        """Return a context manager to use while running inference.

        Uses ``torch.no_grad()`` to disable gradient tracking.
        NOTE(review): presumably chosen deliberately over
        ``torch.inference_mode()`` for the CPU backend; confirm the reason
        before changing it.
        """
        return torch.no_grad()