"""Platform definition for Intel XPU devices."""
from typing import TYPE_CHECKING, Optional

import torch

from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    # Only needed for annotations; avoid importing vllm.config at runtime.
    VllmConfig = None

logger = init_logger(__name__)


class XPUPlatform(Platform):
    """Platform implementation for Intel XPU devices.

    Selects the IPEX attention backend, reports device properties through
    ``torch.xpu``, and adapts the vLLM configuration to what XPU supports
    (no CUDA graphs, no speculative decoding, Ray as distributed executor).
    """

    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
    device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the import path of the attention backend to use on XPU.

        Only the IPEX backend is returned here, whatever ``selected_backend``
        is; an explicit request for another backend is logged and ignored.
        The remaining parameters are accepted for interface compatibility
        and are not used by this implementation.
        """
        # A ``None`` selection means "no explicit choice", so there is nothing
        # to report as overridden in that case.
        if selected_backend is not None and selected_backend != _Backend.IPEX:
            logger.info("Cannot use %s backend on XPU.", selected_backend)
        logger.info("Using IPEX attention backend.")
        return "vllm.attention.backends.ipex_attn.IpexAttnBackend"

    @staticmethod
    def get_device_capability(device_id: int = 0) -> DeviceCapability:
        """Return the device capability parsed from ``torch.xpu``.

        The ``'version'`` entry reported by
        ``torch.xpu.get_device_capability`` is split on ``'.'``; the first two
        components become ``major`` and ``minor``.
        """
        version = torch.xpu.get_device_capability(device_id)['version']
        # Any components beyond major.minor are ignored.
        major, minor, *_ = version.split('.')
        return DeviceCapability(major=int(major), minor=int(minor))

    @staticmethod
    def get_device_name(device_id: int = 0) -> str:
        """Return the device name reported by ``torch.xpu``."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device's ``torch.xpu`` properties."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is always reported as supported on XPU."""
        return True

    @staticmethod
    def inference_mode():
        """Return the context manager used for inference (``torch.no_grad``)."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adapt ``vllm_config`` in place to the features XPU supports.

        - Defaults the cache block size to 16 when unset.
        - Falls back from bfloat16 to float16 on devices without bf16.
        - Forces eager mode, since CUDA graphs are unavailable.
        - Selects the XPU worker and the Ray distributed executor.

        Raises:
            NotImplementedError: if a speculative decoding config is set.
        """
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # check and update model config (it may be absent, like cache_config)
        model_config = vllm_config.model_config
        if model_config is not None:
            if (model_config.dtype == torch.bfloat16
                    and not cls.device_support_bf16()):
                logger.warning(
                    "bfloat16 is only supported on Intel Data Center GPU, "
                    "Intel Arc GPU is not supported yet. Your device is %s, "
                    "which is not supported. Will fall back to float16.",
                    cls.get_device_name())
                model_config.dtype = torch.float16

            if not model_config.enforce_eager:
                logger.warning(
                    "CUDA graph is not supported on XPU, fallback to the eager "
                    "mode.")
                model_config.enforce_eager = True

        if vllm_config.speculative_config is not None:
            raise NotImplementedError(
                "XPU does not support speculative decoding")

        if vllm_config.device_config is not None:
            assert vllm_config.device_config.device_type == "xpu"

        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"

        backend = parallel_config.distributed_executor_backend
        if backend == "mp":
            # FIXME(kunshang):
            # spawn needs calling `if __name__ == '__main__':``
            # fork is not supported for xpu start new process.
            logger.error(
                "Both start methods (spawn and fork) have issue "
                "on XPU if you use mp backend, setting it to ray instead.")
        elif backend is not None and backend != "ray":
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.", backend)
        # Every path (unset, "mp", unsupported, or already "ray") ends on ray.
        parallel_config.distributed_executor_backend = "ray"

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is not used on XPU; always returns False."""
        logger.warning("Pin memory is not supported on XPU.")
        return False

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the currently allocated XPU memory for ``device``.

        Resetting the peak statistics first makes ``max_memory_allocated``
        report the allocation level as of this call.
        """
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def device_support_bf16(cls) -> bool:
        """Return whether the current device is treated as bf16-capable.

        Decided by substring match on the lower-cased device name: names
        containing "arc" are unsupported, names containing "data center gpu"
        are supported, and anything else is logged and treated as unsupported.
        """
        device_name = cls.get_device_name().lower()
        if "arc" in device_name:
            return False
        if "data center gpu" in device_name:
            return True
        logger.warning("Unknown device name %s, always use float16",
                       device_name)
        return False