# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import os
from typing import TYPE_CHECKING

import torch

# import custom ops, trigger op registration
import vllm_xpu_kernels._C  # noqa
import vllm_xpu_kernels._moe_C  # noqa
import vllm_xpu_kernels._xpu_C  # noqa

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.attention.backends.registry import AttentionBackendEnum

from .interface import DeviceCapability, Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.v1.attention.selector import AttentionSelectorConfig
else:
    # VllmConfig is used as an unquoted annotation in this module, so the name
    # must exist at runtime; bind a placeholder instead of importing
    # vllm.config at module load (presumably to avoid an import cycle).
    VllmConfig = None

logger = init_logger(__name__)


class XPUPlatform(Platform):
    """Platform implementation for Intel XPU devices."""

    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
    dist_backend: str = "xccl"  # xccl only
    device_control_env_var: str = "ZE_AFFINITY_MASK"

    @classmethod
    def import_kernels(cls) -> None:
        """Import optional kernel extensions for XPU.

        ``vllm._C`` is intentionally not imported here. ``vllm._moe_C`` is
        imported best-effort: an ImportError is silently ignored.
        """
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use on XPU.

        Side effect: the KV-cache layout is forced to "NHD" on every call,
        since that is the only layout the XPU attention kernels support.

        Selection order:
          1. ``use_sparse``                      -> XPU_MLA_SPARSE
          2. ``use_mla``                         -> TRITON_MLA
          3. TRITON_ATTN explicitly selected     -> TRITON_ATTN
          4. float32 dtype                       -> TRITON_ATTN (fallback)
          5. FLASH_ATTN explicitly selected      -> FLASH_ATTN
          6. any other truthy selection          -> ValueError
          7. nothing selected                    -> FLASH_ATTN

        Args:
            selected_backend: Backend requested by the caller, or a falsy
                value when no explicit backend was requested.
            attn_selector_config: Selector inputs; ``dtype``, ``use_sparse``
                and ``use_mla`` are read here.
            num_heads: Not used by this implementation.

        Returns:
            The dotted import path of the chosen backend class.

        Raises:
            ValueError: If an unsupported backend was explicitly selected.
        """
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout("NHD")
        logger.info_once(
            "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
            "only NHD layout is supported by XPU attention kernels."
        )

        dtype = attn_selector_config.dtype
        if attn_selector_config.use_sparse:
            logger.info_once("Using XPU MLA Sparse backend.")
            return AttentionBackendEnum.XPU_MLA_SPARSE.get_path()
        if attn_selector_config.use_mla:
            logger.info_once("Using Triton MLA backend on V1 engine.")
            return AttentionBackendEnum.TRITON_MLA.get_path()
        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info_once("Using Triton backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        elif dtype == torch.float32:
            # Checked before the explicit FLASH_ATTN branch so that an
            # explicit FLASH_ATTN request with fp32 also falls back.
            logger.warning_once(
                "Flash Attention on XPU does not support float32 dtype. "
                "Falling back to Triton Attention backend."
            )
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend.")
            return AttentionBackendEnum.FLASH_ATTN.get_path()
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_mla: {attn_selector_config.use_mla}"
            )

        logger.info_once("Using Flash Attention backend.")
        return AttentionBackendEnum.FLASH_ATTN.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends accepted for ViT attention on XPU."""
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Choose the ViT attention backend.

        If ``backend`` is given it must be one of
        ``get_supported_vit_attn_backends()`` and is returned unchanged;
        otherwise FLASH_ATTN is returned. ``head_size`` and ``dtype`` are not
        used by this implementation.

        Raises:
            AssertionError: If ``backend`` is given but unsupported.
        """
        # NOTE(review): validation via ``assert`` is skipped under
        # ``python -O``; kept as-is to preserve the exception type.
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: "
                f"{cls.get_supported_vit_attn_backends()}."
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        logger.info_once(
            f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
        )
        return AttentionBackendEnum.FLASH_ATTN

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.xpu.set_device(device)

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Always return None on XPU (see comment below)."""
        # capability format differs from cuda's and will cause unexpected
        # failure, so use None directly
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.xpu``."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the LoRA punica wrapper class path.

        Setting the environment variable ``XPU_USE_TRITON_KERNEL=1`` selects
        the GPU (Triton) wrapper; any other value selects the XPU wrapper.
        """
        xpu_use_triton_kernel = os.getenv("XPU_USE_TRITON_KERNEL", "0") == "1"
        if not xpu_use_triton_kernel:
            return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
        else:
            return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the device's ``total_memory`` property from ``torch.xpu``."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def inference_mode(cls):
        """Return the no-grad context used for inference on XPU.

        Uses ``torch.no_grad()`` rather than ``torch.inference_mode()``.
        NOTE(review): the reason is not visible here; confirm before changing.
        """
        return torch.no_grad()

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the class path of the static-graph wrapper (CUDAGraphWrapper)."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for XPU constraints.

        - Replaces a ``None`` ``compile_sizes`` with an empty list.
        - Defaults the attention backend to FLASH_ATTN when unset.
        - Sets ``cudagraph_mode`` to NONE when XPU graphs are unsupported by
          the installed PyTorch, disabled via ``VLLM_XPU_ENABLE_XPU_GRAPH``,
          or when ``world_size_across_dp > 1``. Otherwise FLASH_ATTN is capped
          at PIECEWISE graph mode.
        - Selects the XPU worker when ``worker_cls`` is still "auto".
        - Enables ``enable_permute_local_kv`` when KV transfer is configured.
        - Sets ``UCX_MEMTYPE_CACHE=n`` in the process environment.
        """
        parallel_config = vllm_config.parallel_config

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if compilation_config.compile_sizes is None:
            compilation_config.compile_sizes = []

        attention_config = vllm_config.attention_config
        if attention_config.backend is None:
            attention_config.backend = AttentionBackendEnum.FLASH_ATTN
        if not supports_xpu_graph():
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            logger.warning(
                "XPU Graph is not supported in the current PyTorch version, "
                "disabling cudagraph_mode."
            )
        elif not envs.VLLM_XPU_ENABLE_XPU_GRAPH:
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            logger.warning(
                "XPU Graph is disabled by environment variable, "
                "please set VLLM_XPU_ENABLE_XPU_GRAPH=1 to enable it."
            )
        elif parallel_config.world_size_across_dp > 1:
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            logger.warning(
                "XPU Graph doesn't support capture communication ops, "
                "disabling cudagraph_mode."
            )
        else:
            if (
                attention_config.backend == AttentionBackendEnum.FLASH_ATTN
                and compilation_config.cudagraph_mode
                not in {CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE}
            ):
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
                logger.warning(
                    "FMHA sycl-tla kernels cannot be captured with XPU graphs, "
                    "falling back to PIECEWISE graph mode on XPU platform."
                )

        # check and update parallel config
        # Only override worker_cls if it's still the default "auto"
        # This allows custom workers (like vllm-omni workers) to be used on XPU
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"

        if vllm_config.kv_transfer_config is not None:
            vllm_config.kv_transfer_config.enable_permute_local_kv = True

        # In some cases, the internal memory type cache can misdetect GPU
        # memory as host memory, also leading to invalid memory access.
        # This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.
        # ref. https://openucx.readthedocs.io/en/master/faq.html
        os.environ["UCX_MEMTYPE_CACHE"] = "n"

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """XPU supports hybrid KV cache."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """XPU supports static graph mode."""
        return True

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is always reported as available on XPU."""
        return True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the memory currently allocated on ``device``.

        The cache is emptied and the peak counter is reset first, so the peak
        read afterwards reflects allocations at the time of the call.
        """
        torch.xpu.empty_cache()
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the FP8 dtype used on XPU (``float8_e4m3fn``)."""
        return torch.float8_e4m3fn

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """Return True if the device name contains "data center gpu"."""
        device_name = cls.get_device_name().lower()
        return "data center gpu" in device_name

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the XPU communicator class path.

        Logs a warning (but still returns the path) when the torch build does
        not support xccl.
        """
        from vllm.utils.torch_utils import supports_xccl

        if not supports_xccl():
            logger.warning(
                "xccl is not enabled in this torch build, communication"
                " is not available."
            )
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa

    @classmethod
    def device_count(cls) -> int:
        """Return the number of visible XPU devices."""
        return torch.xpu.device_count()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Reject dtypes known to be problematic on the current device.

        Raises:
            ValueError: If ``dtype`` is bfloat16 and the device name contains
                "a770" (Intel Arc A770).
        """
        if dtype == torch.bfloat16:  # noqa: SIM102
            device_name = cls.get_device_name().lower()
            # client gpu a770
            if "a770" in device_name:
                raise ValueError(
                    "Intel Arc A770 have bfloat16 accuracy known issue. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Attention is treated as an opaque op on XPU."""
        return True

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU.

        Block indices select along dimension 1 of both caches. Source blocks
        are moved to ``dst_cache.device`` before assignment.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU).

        Block indices select along dimension 1 of both caches.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """Return the device's ``max_compute_units`` property."""
        return torch.xpu.get_device_properties(device_id).max_compute_units