xpu.py 14 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import os
6
from typing import TYPE_CHECKING
7

8
9
import torch

10
11
12
13
14
# import custom ops, trigger op registration
import vllm_xpu_kernels._C  # noqa
import vllm_xpu_kernels._moe_C  # noqa
import vllm_xpu_kernels._xpu_C  # noqa

15
import vllm.envs as envs
16
from vllm.logger import init_logger
17
from vllm.utils.torch_utils import supports_xpu_graph
18
from vllm.v1.attention.backends.registry import AttentionBackendEnum
19

20
from .interface import DeviceCapability, Platform, PlatformEnum
21

22
if TYPE_CHECKING:
23
    from vllm.config import VllmConfig
24
    from vllm.config.kernel import IrOpPriorityConfig
25
    from vllm.v1.attention.selector import AttentionSelectorConfig
26
27
28
else:
    VllmConfig = None

29
# Module-level logger used by the XPU platform hooks below.
logger = init_logger(__name__)
30
31
32
33


class XPUPlatform(Platform):
    _enum = PlatformEnum.XPU

    # Device naming for this platform.
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"

    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"

    # Distributed backend: XCCL is the only one used on XPU.
    dist_backend: str = "xccl"

    # Environment variable used to control which devices are used.
    device_control_env_var: str = "ZE_AFFINITY_MASK"
42

43
    @classmethod
    def import_kernels(cls) -> None:
        """Load the optional MoE kernel extension for XPU.

        ``vllm._C`` is deliberately not imported on this platform. Only
        ``vllm._moe_C`` is attempted, and an ``ImportError`` is tolerated.
        """
        try:
            import vllm._moe_C  # noqa: F401
        except ImportError:
            pass
48

49
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use on XPU.

        The KV-cache layout is always forced to NHD first. Backends are then
        chosen in this order:

        1. TurboQuant, when ``kv_cache_dtype`` starts with ``"turboquant_"``.
        2. XPU MLA sparse, when ``use_sparse`` is set.
        3. Triton MLA, when ``use_mla`` is set.
        4. Triton, when explicitly selected.
        5. Triton, as a fallback when ``dtype`` is float32.
        6. Flash Attention, when explicitly selected.
        7. Flash Attention, when no backend was selected.

        Args:
            selected_backend: Backend requested by the caller, or a falsy
                value (e.g. ``None``) when nothing was requested.
            attn_selector_config: Attention selection inputs. This method
                reads ``kv_cache_dtype``, ``dtype``, ``use_sparse`` and
                ``use_mla``.
            num_heads: Not used by the XPU selection logic.

        Returns:
            The fully qualified class path of the chosen backend.

        Raises:
            ValueError: If ``selected_backend`` is set to a backend other than
                TRITON_ATTN or FLASH_ATTN on the non-MLA path.
        """
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout("NHD")
        # The message never changes between calls, so emit it only once.
        logger.info_once(
            "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
            "only NHD layout is supported by XPU attention kernels."
        )

        # TurboQuant KV cache: route directly to TQ backend
        kv_cache_dtype = attn_selector_config.kv_cache_dtype
        if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
            logger.info_once("Using TurboQuant attention backend.")
            return AttentionBackendEnum.TURBOQUANT.get_path()

        if attn_selector_config.use_sparse:
            logger.info_once("Using XPU MLA Sparse backend.")
            return AttentionBackendEnum.XPU_MLA_SPARSE.get_path()
        if attn_selector_config.use_mla:
            logger.info_once("Using Triton MLA backend on V1 engine.")
            return AttentionBackendEnum.TRITON_MLA.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info_once("Using Triton backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        if attn_selector_config.dtype == torch.float32:
            logger.warning_once(
                "Flash Attention on XPU does not support float32 dtype. "
                "Falling back to Triton Attention backend."
            )
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        if selected_backend == AttentionBackendEnum.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend.")
            return AttentionBackendEnum.FLASH_ATTN.get_path()
        if selected_backend:
            # use_mla is always False here (the MLA case returned above), so
            # report the rejected backend itself rather than that flag.
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}: "
                f"{selected_backend}. Supported backends are TRITON_ATTN "
                "and FLASH_ATTN."
            )

        logger.info_once("Using Flash Attention backend.")
        return AttentionBackendEnum.FLASH_ATTN.get_path()
97

98
99
100
101
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends accepted for vision-transformer layers.

        A fresh list is built on every call.
        """
        supported_backends = [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
        ]
        return supported_backends

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for vision-transformer layers.

        ``head_size`` and ``dtype`` do not affect the choice on XPU.

        Args:
            backend: Explicitly requested backend. When given, it is checked
                (via ``assert``) against ``get_supported_vit_attn_backends``
                and returned as-is. When ``None``, Flash Attention is used.

        Returns:
            The selected backend enum member.
        """
        if backend is None:
            chosen = AttentionBackendEnum.FLASH_ATTN
        else:
            supported = cls.get_supported_vit_attn_backends()
            assert backend in supported, (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: "
                f"{supported}."
            )
            chosen = backend

        logger.info_once(f"Using backend {chosen} for vit attention")
        return chosen

127
128
129
130
131
132
133
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Select ``device`` as the current XPU device.

        Delegates to ``torch.xpu.set_device``.
        """
        xpu_backend = torch.xpu
        xpu_backend.set_device(device)

134
135
136
137
    @classmethod
    def manual_seed_all(cls, seed: int) -> None:
        """Seed the random number generators of all XPU devices with ``seed``."""
        seed_every_device = torch.xpu.manual_seed_all
        seed_every_device(seed)

138
    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Report the device capability of an XPU device.

        Always returns ``None`` and ignores ``device_id``. XPU's capability
        format differs from CUDA's, and reporting it through
        ``DeviceCapability`` would cause unexpected failures.
        """
        capability: DeviceCapability | None = None
        return capability
146

147
148
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name that ``torch.xpu.get_device_name`` reports for ``device_id``."""
        name = torch.xpu.get_device_name(device_id)
        return name
150

151
152
    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper class.

        ``XPU_USE_TRITON_KERNEL=1`` selects ``PunicaWrapperGPU``. Any other
        value, or leaving the variable unset, selects ``PunicaWrapperXPU``.
        """
        triton_requested = os.getenv("XPU_USE_TRITON_KERNEL", "0") == "1"
        return (
            "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
            if triton_requested
            else "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
        )
158

159
160
161
162
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the torch properties of XPU ``device_id``."""
        return torch.xpu.get_device_properties(device_id).total_memory
163

164
165
    @classmethod
    def inference_mode(cls):
        """Return the context manager that disables gradient tracking.

        On XPU this is ``torch.no_grad()``.
        """
        grad_disabled = torch.no_grad()
        return grad_disabled
167

168
169
170
171
    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static graph wrapper.

        XPU reuses the CUDA graph wrapper class.
        """
        wrapper_module = "vllm.compilation.cuda_graph"
        return f"{wrapper_module}.CUDAGraphWrapper"

172
173
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adapt ``vllm_config`` in place to what the XPU platform supports.

        Steps, in order:

        - Replace ``compilation_config.compile_sizes=None`` with ``[]``.
        - Default ``attention_config.backend`` to Flash Attention.
        - Constrain ``compilation_config.cudagraph_mode`` via
          ``_update_xpu_graph_mode``.
        - Use ``XPUWorker`` unless a custom ``worker_cls`` was configured.
        - Set ``enable_permute_local_kv`` when KV transfer is configured.
        - Set ``UCX_MEMTYPE_CACHE=n`` in this process's environment.
        """
        parallel_config = vllm_config.parallel_config
        compilation_config = vllm_config.compilation_config
        attention_config = vllm_config.attention_config

        if compilation_config.compile_sizes is None:
            compilation_config.compile_sizes = []

        if attention_config.backend is None:
            attention_config.backend = AttentionBackendEnum.FLASH_ATTN

        # Must run after the backend default above: the graph-mode decision
        # depends on which attention backend is configured.
        cls._update_xpu_graph_mode(vllm_config)

        # Only override worker_cls if it's still the default "auto".
        # This allows custom workers (like vllm-omni workers) to be used on XPU.
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"

        if vllm_config.kv_transfer_config is not None:
            vllm_config.kv_transfer_config.enable_permute_local_kv = True

        # In some cases, UCX's internal memory type cache can misdetect GPU
        # memory as host memory, leading to invalid memory access.
        # This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.
        # ref. https://openucx.readthedocs.io/en/master/faq.html
        os.environ["UCX_MEMTYPE_CACHE"] = "n"

    @classmethod
    def _update_xpu_graph_mode(cls, vllm_config: "VllmConfig") -> None:
        """Downgrade ``cudagraph_mode`` to a mode XPU graphs can capture.

        Graphs are disabled (``NONE``) in three cases:

        - PyTorch lacks XPU graph support.
        - ``VLLM_XPU_ENABLE_XPU_GRAPH`` is off.
        - The world size across data parallel is greater than one.

        Otherwise, with the Flash Attention backend, any mode other than
        ``NONE``/``PIECEWISE`` is reduced to ``PIECEWISE``.
        """
        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if not supports_xpu_graph():
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            logger.warning(
                "XPU Graph is not supported in the current PyTorch version, "
                "disabling cudagraph_mode."
            )
        elif not envs.VLLM_XPU_ENABLE_XPU_GRAPH:
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            logger.warning(
                "XPU Graph is disabled by environment variable, "
                "please set VLLM_XPU_ENABLE_XPU_GRAPH=1 to enable it."
            )
        elif vllm_config.parallel_config.world_size_across_dp > 1:
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            logger.warning(
                "XPU Graph doesn't support capture communication ops, "
                "disabling cudagraph_mode."
            )
        elif (
            vllm_config.attention_config.backend == AttentionBackendEnum.FLASH_ATTN
            and compilation_config.cudagraph_mode
            not in {CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE}
        ):
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            logger.warning(
                "FMHA sycl-tla kernels cannot be captured with XPU graphs, "
                "falling back to PIECEWISE graph mode on XPU platform."
            )

231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
    @classmethod
    def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
        """Round the cache block size up to a size the XPU GDN kernel accepts.

        The base-class update runs first. If any attention layer then uses
        the ``GDN_ATTN`` backend, ``cache_config.block_size`` is rounded up to
        a multiple of 64.

        When that changes the block size, two related fields are updated:

        - ``mamba_block_size`` follows the new block size, but only in
          ``"align"`` mamba cache mode.
        - ``mamba_page_size_padded``, when set, is rescaled so its per-token
          size stays the same.
        """
        super().update_block_size_for_backend(vllm_config)
        from vllm.config.vllm import get_layers_from_vllm_config
        from vllm.model_executor.layers.attention_layer_base import (
            AttentionLayerBase,
        )
        from vllm.utils.math_utils import cdiv

        cache_config = vllm_config.cache_config
        attn_layers = get_layers_from_vllm_config(
            vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        # Special fix for GDN: its kernel only supports block sizes
        # divisible by 64.
        uses_gdn = any(
            layer.get_attn_backend().get_name() == "GDN_ATTN"
            for layer in attn_layers.values()
        )
        if not uses_gdn:
            return
        kernel_block_size = 64

        old_block_size = cache_config.block_size
        new_block_size = cdiv(old_block_size, kernel_block_size) * kernel_block_size
        if new_block_size == old_block_size:
            return

        if cache_config.mamba_cache_mode == "align":
            cache_config.mamba_block_size = new_block_size
        cache_config.block_size = new_block_size

        old_page_size_padded = cache_config.mamba_page_size_padded
        if old_page_size_padded is None:
            # Nothing to rescale. Logging page sizes here would feed None to
            # a "%d" placeholder and raise inside the logging machinery.
            logger.info(
                "[XPU]Setting attention block size to %d tokens to ensure "
                "multiple of %d.",
                new_block_size,
                kernel_block_size,
            )
            return

        # Keep the padded page size per token constant across the resize.
        # NOTE(review): assumes the old padded size is a multiple of the old
        # block size; otherwise the floor division drops the remainder.
        page_size_per_token = old_page_size_padded // old_block_size
        cache_config.mamba_page_size_padded = new_block_size * page_size_per_token
        logger.info(
            "[XPU]Setting attention block size to %d tokens to ensure multiple of %d, "
            "set mamba_page_size_padded to %d bytes accordingly, before was %d bytes.",
            new_block_size,
            kernel_block_size,
            cache_config.mamba_page_size_padded,
            old_page_size_padded,
        )

282
283
284
285
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Always returns True on XPU.

        See ``Platform.support_hybrid_kv_cache`` for the meaning of this flag.
        """
        hybrid_supported = True
        return hybrid_supported

286
287
    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Always returns True on XPU.

        See ``Platform.support_static_graph_mode`` for the meaning of this flag.
        """
        static_graph_supported = True
        return static_graph_supported
289

290
291
    @classmethod
    def is_pin_memory_available(cls):
        """Report that pinned host memory is available on XPU (always True)."""
        pinned_ok = True
        return pinned_ok
293
294

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the memory currently allocated on ``device``.

        The method first releases cached blocks, then resets the peak-memory
        counter, and finally reads that counter back. The returned "peak" is
        therefore effectively the current allocation.

        Side effect: the device's peak-memory statistics are cleared.
        """
        xpu_backend = torch.xpu
        xpu_backend.empty_cache()
        xpu_backend.reset_peak_memory_stats(device)
        return xpu_backend.max_memory_allocated(device)
301

302
303
    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the FP8 dtype used on XPU: ``torch.float8_e4m3fn``."""
        fp8 = torch.float8_e4m3fn
        return fp8
305

306
307
308
309
310
    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """Return True if the device name contains "data center gpu".

        The match is case-insensitive. The name comes from
        ``get_device_name()`` with its default ``device_id``.
        """
        return "data center gpu" in cls.get_device_name().lower()

311
312
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the XPU device communicator.

        If the installed torch build lacks XCCL, a warning is logged. The
        path is returned either way.
        """
        from vllm.utils.torch_utils import supports_xccl

        xccl_available = supports_xccl()
        if not xccl_available:
            logger.warning(
                "xccl is not enabled in this torch build, communication"
                " is not available."
            )
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
321

322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
    @classmethod
    def get_default_ir_op_priority(
        cls, vllm_config: "VllmConfig"
    ) -> "IrOpPriorityConfig":
        """Build the default IR op implementation priority for XPU.

        When compiling with Inductor (backend ``"inductor"`` and a mode other
        than ``CompilationMode.NONE``), only ``"native"`` is used. Otherwise
        fused ``"xpu_kernels"`` are preferred, with ``"native"`` as the
        fallback.
        """
        from vllm.config.compilation import CompilationMode
        from vllm.config.kernel import IrOpPriorityConfig

        compilation = vllm_config.compilation_config
        inductor_codegen = (
            compilation.backend == "inductor"
            and compilation.mode != CompilationMode.NONE
        )
        if inductor_codegen:
            # Native ops are used by default when compiling.
            priority = ["native"]
        else:
            # No codegen: use fused kernels where available.
            priority = ["xpu_kernels", "native"]
        return IrOpPriorityConfig.with_default(priority)

337
338
    @classmethod
    def device_count(cls) -> int:
        """Return the number of XPU devices reported by ``torch.xpu``."""
        visible = torch.xpu.device_count()
        return visible
340
341

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Reject dtypes with known problems on the current XPU device.

        Raises:
            ValueError: If ``dtype`` is bfloat16 and the device name contains
                "a770" (case-insensitive).
        """
        # Short-circuit: the device name is only queried for bfloat16.
        if dtype == torch.bfloat16 and "a770" in cls.get_device_name().lower():
            raise ValueError(
                "Intel Arc A770 have bfloat16 accuracy known issue. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half."
            )
352
353
354
355

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Always returns True on XPU.

        See ``Platform.opaque_attention_op`` for the meaning of this flag.
        """
        opaque = True
        return opaque
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU.

        ``src_block_indices`` selects entries along dim 1 of ``src_cache``.
        The selection is moved to ``dst_cache``'s device and written into
        ``dst_cache`` at ``dst_block_indices`` along dim 1.
        """
        dst_cache[:, dst_block_indices] = src_cache[:, src_block_indices].to(
            dst_cache.device
        )

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU).

        ``src_block_indices`` selects entries along dim 1 of ``src_cache``.
        The selection is copied to CPU and written into ``dst_cache`` at
        ``dst_block_indices`` along dim 1.
        """
        dst_cache[:, dst_block_indices] = src_cache[:, src_block_indices].cpu()
380
381
382
383

    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """Return ``max_compute_units`` from the properties of XPU ``device_id``."""
        props = torch.xpu.get_device_properties(device_id)
        return props.max_compute_units