tpu.py 196 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
import torch

from .interface import Platform, PlatformEnum


class TpuPlatform(Platform):
    """``Platform`` implementation tagged with ``PlatformEnum.TPU``."""

    # Enum tag identifying which platform this class represents.
    _enum = PlatformEnum.TPU

    @staticmethod
    def inference_mode():
        """Return the context manager to wrap inference calls in.

        This platform disables gradient tracking with ``torch.no_grad()``
        rather than ``torch.inference_mode()``.

        Returns:
            A fresh ``torch.no_grad`` context manager instance.
        """
        # NOTE(review): presumably no_grad is used because
        # torch.inference_mode is not supported on the TPU/XLA backend
        # — TODO confirm.
        grad_disabled_ctx = torch.no_grad()
        return grad_disabled_ctx