# hpu.py
import torch

from .interface import Platform, PlatformEnum, _Backend


class HpuPlatform(Platform):
    """Platform implementation for HPU devices (torch device type ``"hpu"``)."""

    # Enum tag identifying this platform.
    _enum = PlatformEnum.HPU
    # Device type string for this platform.
    device_type: str = "hpu"

    @classmethod
    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
        """Return the attention backend to use on this platform.

        Args:
            selected_backend: The backend requested by the caller. It is
                ignored; this platform always answers with
                ``_Backend.HPU_ATTN``.

        Returns:
            ``_Backend.HPU_ATTN``.
        """
        return _Backend.HPU_ATTN

    @staticmethod
    def inference_mode():
        """Return a context manager that disables gradient tracking.

        Returns:
            A new ``torch.no_grad()`` context manager.
        """
        # NOTE(review): uses ``torch.no_grad()`` rather than
        # ``torch.inference_mode()``; presumably the latter is not suitable
        # on HPU — confirm before changing.
        return torch.no_grad()