# xpu.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING, Optional
6

7
8
import torch

9
import vllm.envs as envs
10
from vllm.logger import init_logger
11
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
12
13
14

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

15
if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
else:
    # At runtime these names are bound to placeholders instead of importing
    # vllm.config; presumably the real import is deferred to avoid an import
    # cycle or import-time cost — TODO confirm. Any annotation that uses them
    # evaluates to None at runtime.
    ModelConfig = None
    VllmConfig = None

logger = init_logger(__name__)

class XPUPlatform(Platform):
    """Platform hooks for Intel GPUs accessed through ``torch.xpu``."""

    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
    dist_backend: str = "ccl"  # ccl | xccl
    device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the import path of the attention backend to use on XPU.

        The choice depends only on ``use_v1``: FlashAttention for the V1
        engine, IPEX attention otherwise. ``head_size``, ``dtype``,
        ``kv_cache_dtype``, ``block_size`` and ``use_mla`` are accepted for
        interface compatibility and are not consulted here.

        Args:
            selected_backend: Backend requested by the user, or ``None``.
                A request other than ``_Backend.IPEX`` is logged and then
                ignored.
            use_v1: Whether the V1 engine is active.

        Returns:
            Fully qualified class path of the chosen backend.
        """
        if selected_backend is not None and selected_backend != _Backend.IPEX:
            logger.info("Cannot use %s backend on XPU.", selected_backend)
        # Honor the caller-supplied flag. Previously it was overwritten with
        # envs.VLLM_USE_V1, which silently discarded the argument.
        if use_v1:
            logger.info("Using Flash Attention backend on V1 engine.")
            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
        logger.info("Using IPEX attention backend.")
        return "vllm.attention.backends.ipex_attn.IpexAttnBackend"

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current XPU device for this process."""
        torch.xpu.set_device(device)

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Always return ``None`` on XPU.

        The XPU capability format differs from CUDA's and would cause
        unexpected failures in callers expecting CUDA semantics.
        """
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name reported by ``torch.xpu`` for ``device_id``."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper (GPU variant)."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the device's ``total_memory`` property as reported by torch."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is reported as supported unconditionally."""
        return True

    @classmethod
    def inference_mode(cls):
        """Return the grad-disabling context used for inference.

        Uses ``torch.no_grad()`` rather than ``torch.inference_mode()``.
        """
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place to satisfy XPU constraints.

        Adjustments made:
          * default the KV-cache block size (64 on V1, 16 otherwise);
          * disable compilation on V1 and force eager mode;
          * fall back from bfloat16 to float16 when the device lacks support;
          * select the XPU worker class and a supported executor backend;
          * disable chunked prefill when MLA is enabled.
        """
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config

        # In V1 (or with IPEX chunked prefill) block_size is 64.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 64 if envs.VLLM_USE_V1 else 16

        # FIXME: Temporarily forcing eager mode
        # remove after t.compile support stabilizes.
        if (envs.VLLM_USE_V1 and model_config is not None
                and not model_config.enforce_eager):
            from vllm.config import CompilationLevel
            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501

        # Instances created using VllmConfig() typically have model_config as
        # None by default, so every model_config access is guarded.
        if model_config is not None:
            if (model_config.dtype == torch.bfloat16
                    and not cls.device_support_bf16()):
                model_config.dtype = torch.float16
            if not model_config.enforce_eager:
                logger.warning(
                    "CUDA graph is not supported on XPU, fallback to the eager "
                    "mode.")
                model_config.enforce_eager = True

        # Check and update the parallel config.
        parallel_config = vllm_config.parallel_config
        if envs.VLLM_USE_V1:
            parallel_config.worker_cls = (
                "vllm.v1.worker.xpu_worker.XPUWorker")
        else:
            parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"

        backend = parallel_config.distributed_executor_backend
        if backend is None:
            parallel_config.distributed_executor_backend = (
                "ray" if parallel_config.world_size > 1 else "uni")
        elif backend == "mp":
            # FIXME(kunshang):
            # spawn needs calling `if __name__ == '__main__':`
            # fork is not supported for xpu start new process.
            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
                logger.warning(
                    "Please use spawn as start method if you want to use mp.")
        elif backend not in ("ray", "uni", "external_launcher"):
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.", backend)
            parallel_config.distributed_executor_backend = "ray"

        if model_config and model_config.use_mla:
            # NOTE(review): the message also claims prefix caching is
            # disabled, but only chunked prefill is turned off below;
            # confirm whether cache_config.enable_prefix_caching should be
            # cleared too.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            scheduler_config = vllm_config.scheduler_config
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            # Without chunked prefill a whole prompt must fit in one batch.
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is reported as available on XPU."""
        return True

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the bytes currently allocated on ``device``.

        The peak counter is reset first, so the subsequent
        ``max_memory_allocated`` reflects the current allocation. Side
        effect: any previously recorded peak statistic is discarded.
        """
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def device_support_bf16(cls) -> bool:
        """Return whether bfloat16 should be used on the current device.

        Intel Arc A770 is excluded because of a known bfloat16 accuracy
        issue; every other device is assumed to support it.
        """
        if cls.is_client_gpu_a770():
            logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
                           " fallback to float16")
            return False
        logger.info(
            "Device name %s supports bfloat16. Please file an issue "
            "if you encounter any accuracy problems with bfloat16.",
            cls.get_device_name().lower())
        return True

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """Return True if the device name contains "data center gpu"."""
        return "data center gpu" in cls.get_device_name().lower()

    @classmethod
    def is_client_gpu_a770(cls) -> bool:
        """Return True if the device name contains "a770" (case-insensitive)."""
        return "a770" in cls.get_device_name().lower()

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the XPU device communicator."""
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """The V1 engine is reported as supported for every model."""
        return True

    @classmethod
    def device_count(cls) -> int:
        """Return the number of XPU devices visible to torch."""
        return torch.xpu.device_count()