hpu.py 336 Bytes
Newer Older
1
2
import torch

3
from .interface import Platform, PlatformEnum, _Backend
4
5
6
7
8


class HpuPlatform(Platform):
    """Platform description for HPU devices (``PlatformEnum.HPU``).

    Overrides the attention-backend selection and the inference-mode
    context manager used on this platform.
    """

    _enum = PlatformEnum.HPU

    @classmethod
    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
        """Return the attention backend to use on HPU.

        Args:
            selected_backend: The backend requested by the caller. It is
                accepted but not consulted; presumably the parameter exists
                to match the base ``Platform`` signature (confirm).

        Returns:
            Always ``_Backend.HPU_ATTN``.
        """
        return _Backend.HPU_ATTN

    @staticmethod
    def inference_mode():
        """Return the context manager under which inference should run.

        Returns a fresh ``torch.no_grad()`` context rather than
        ``torch.inference_mode()``.
        NOTE(review): presumably because inference_mode is not supported
        on HPU -- confirm against the backend's documentation.
        """
        return torch.no_grad()