# cpu.py
1
import psutil
2
3
4
5
6
7
8
9
import torch

from .interface import Platform, PlatformEnum


class CpuPlatform(Platform):
    """Platform implementation for execution on the host CPU.

    Every device query below ignores ``device_id``; the host is treated as a
    single device.
    """

    _enum = PlatformEnum.CPU

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name, which is always ``"cpu"``.

        Args:
            device_id: Ignored. Presumably accepted only to match the
                ``Platform`` interface signature.
        """
        return "cpu"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the host's total physical memory in bytes.

        The value is ``psutil.virtual_memory().total``.

        Args:
            device_id: Ignored. Presumably accepted only to match the
                ``Platform`` interface signature.
        """
        return psutil.virtual_memory().total

    @classmethod
    def inference_mode(cls):
        """Return a fresh ``torch.no_grad()`` context manager for inference.

        NOTE(review): this deliberately returns ``torch.no_grad()`` instead of
        ``torch.inference_mode()``. Presumably some CPU code paths do not
        work under inference mode; confirm before switching.
        """
        return torch.no_grad()