xpu.py 1.09 KB
Newer Older
1
2
import torch

3
4
5
6
7
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

logger = init_logger(__name__)
8
9
10
11
12


class XPUPlatform(Platform):
    _enum = PlatformEnum.XPU

13
14
15
16
17
18
    @classmethod
    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
        if selected_backend != _Backend.IPEX:
            logger.info("Cannot use %s backend on XPU.", selected_backend)
        return _Backend.IPEX

19
20
    @staticmethod
    def get_device_capability(device_id: int = 0) -> DeviceCapability:
21
22
23
        major, minor, *_ = torch.xpu.get_device_capability(
            device_id)['version'].split('.')
        return DeviceCapability(major=int(major), minor=int(minor))
24
25
26
27

    @staticmethod
    def get_device_name(device_id: int = 0) -> str:
        return torch.xpu.get_device_name(device_id)
28
29
30
31
32

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory
33
34
35
36

    @staticmethod
    def inference_mode():
        return torch.no_grad()