# cpu.py -- defines CpuPlatform; last committed by zhuwenwen.
import psutil
import torch

from .interface import Platform, PlatformEnum


class CpuPlatform(Platform):
    """``Platform`` implementation for the host CPU (``PlatformEnum.CPU``).

    Every hook is a classmethod, so no instance state is involved.
    """

    _enum = PlatformEnum.CPU

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name.

        ``device_id`` is accepted for interface compatibility but is not
        consulted; the name is always ``"cpu"``.
        """
        return "cpu"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the host's total physical memory, in bytes.

        ``device_id`` is not consulted. The value is
        ``psutil.virtual_memory().total``, which is system-wide RAM
        rather than a per-device figure.
        """
        return psutil.virtual_memory().total

    @classmethod
    def inference_mode(cls):
        """Return a context manager that disables gradient tracking.

        Uses ``torch.no_grad()`` rather than ``torch.inference_mode()``.
        NOTE(review): presumably this avoids inference-mode tensor
        restrictions on CPU -- confirm the rationale before changing it.
        """
        return torch.no_grad()