# tpu.py
import torch

from .interface import Platform, PlatformEnum


class TpuPlatform(Platform):
    """Platform description for TPU devices, identified by ``PlatformEnum.TPU``.

    Device-introspection queries are not implemented for this platform yet;
    only the inference context manager is provided.
    """

    _enum = PlatformEnum.TPU

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of TPU device ``device_id``.

        Raises:
            NotImplementedError: always; not supported on TPU yet.
        """
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of TPU device ``device_id``.

        Raises:
            NotImplementedError: always; not supported on TPU yet.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """Return a context manager that disables gradient tracking.

        A fresh ``torch.no_grad()`` instance is created on every call.
        NOTE(review): presumably ``no_grad`` is used instead of
        ``torch.inference_mode()`` for TPU/XLA compatibility; confirm.
        """
        return torch.no_grad()