# xpu.py
from typing import TYPE_CHECKING, Optional
2

3
4
import torch

5
6
7
8
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

9
10
11
12
13
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

14
logger = init_logger(__name__)
15
16
17
18


class XPUPlatform(Platform):
    """Platform hooks for running vLLM on XPU devices through ``torch.xpu``.

    Provides device queries (capability, name, total memory) and rewrites the
    vLLM config to the settings this platform supports: IPEX attention,
    float16 instead of bfloat16, eager execution, the Ray distributed
    executor, and the XPU worker class.
    """

    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"

    @classmethod
    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
        """Return the attention backend to use on XPU.

        IPEX is the only backend returned here; any other selection is
        logged at info level and overridden.
        """
        if selected_backend != _Backend.IPEX:
            logger.info("Cannot use %s backend on XPU.", selected_backend)
        return _Backend.IPEX

    @staticmethod
    def get_device_capability(device_id: int = 0) -> DeviceCapability:
        """Return the (major, minor) capability of XPU device ``device_id``.

        Parsed from the first two dot-separated components of the
        ``'version'`` entry returned by ``torch.xpu.get_device_capability``.
        """
        # NOTE(review): assumes the returned mapping has a 'version' key shaped
        # like "<major>.<minor>[.<more>]" — confirm for the torch build in use.
        version = torch.xpu.get_device_capability(device_id)['version']
        major, minor, *_ = version.split('.')
        return DeviceCapability(major=int(major), minor=int(minor))

    @staticmethod
    def get_device_name(device_id: int = 0) -> str:
        """Return the name ``torch.xpu`` reports for device ``device_id``."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device's ``torch.xpu`` properties."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Report async output processing as supported (``enforce_eager`` is
        not consulted)."""
        return True

    @staticmethod
    def inference_mode():
        """Return the context manager used for inference: ``torch.no_grad()``."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place to settings XPU supports.

        - Sets the cache block size to 16 when it is unset.
        - Casts bfloat16 models to float16 and forces eager mode.
        - Rejects speculative decoding.
        - Replaces any non-Ray distributed executor backend with Ray, and
          picks the XPU worker class when ``worker_cls`` is ``"auto"``.

        Raises:
            NotImplementedError: if a speculative decoding config is present.
        """
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # check and update model config
        model_config = vllm_config.model_config
        # Guarded the same way as cache_config above, so a config without a
        # model config does not fail with AttributeError here.
        if model_config is not None:
            if model_config.dtype == torch.bfloat16:
                logger.warning(
                    "bfloat16 is not fully supported on XPU, casting to float16.")
                model_config.dtype = torch.float16
            if not model_config.enforce_eager:
                logger.warning(
                    "CUDA graph is not supported on XPU, fallback to the eager "
                    "mode.")
                model_config.enforce_eager = True

        if vllm_config.speculative_config is not None:
            raise NotImplementedError(
                "XPU does not support speculative decoding")

        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        if (parallel_config.distributed_executor_backend is not None
                and parallel_config.distributed_executor_backend != "ray"):
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.",
                parallel_config.distributed_executor_backend)
            parallel_config.distributed_executor_backend = "ray"

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"

    @classmethod
    def is_pin_memory_available(cls):
        """Return False; pinned host memory is not used on XPU.

        Logs a warning on every call.
        """
        logger.warning("Pin memory is not supported on XPU.")
        return False