# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import os
from typing import TYPE_CHECKING

import torch

# import custom ops, trigger op registration
import vllm_xpu_kernels._C  # noqa
import vllm_xpu_kernels._moe_C  # noqa
import vllm_xpu_kernels._xpu_C  # noqa

from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.attention.backends.registry import AttentionBackendEnum

from .interface import DeviceCapability, Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.v1.attention.selector import AttentionSelectorConfig
else:
    # Runtime placeholder: `check_and_update_config` annotates its argument
    # with the unquoted name `VllmConfig`, which must resolve at import time
    # even though vllm.config is only imported for type checking here
    # (presumably to avoid a circular import).
    VllmConfig = None

logger = init_logger(__name__)


class XPUPlatform(Platform):
    """Platform implementation for Intel XPU devices.

    Maps vLLM's platform hooks onto ``torch.xpu`` and selects the
    XPU-specific attention, LoRA (Punica), graph-capture, worker and
    communicator implementations.
    """

    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
    dist_backend: str = "xccl"  # xccl only
    device_control_env_var: str = "ZE_AFFINITY_MASK"

    @classmethod
    def import_kernels(cls) -> None:
        """Import optional kernel extensions for XPU.

        ``vllm._C`` is intentionally not imported. The XPU custom ops come
        from the ``vllm_xpu_kernels`` imports at the top of this module.
        ``vllm._moe_C`` is imported best-effort; an ImportError is ignored.
        """
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use on XPU.

        Selection order:
          1. sparse attention         -> XPU_MLA_SPARSE
          2. MLA                      -> TRITON_MLA
          3. TRITON_ATTN selected     -> TRITON_ATTN
          4. float32 dtype            -> TRITON_ATTN (no fp32 Flash Attention)
          5. FLASH_ATTN or no choice  -> FLASH_ATTN

        Side effect: the KV-cache layout is forced to ``"NHD"`` on every call.
        ``num_heads`` is accepted for interface compatibility and not used.

        Raises:
            ValueError: if any other backend is explicitly selected.
        """
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout("NHD")
        # This runs on every backend lookup; log it once to avoid noise.
        logger.info_once(
            "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
            "only NHD layout is supported by XPU attention kernels."
        )

        dtype = attn_selector_config.dtype
        if attn_selector_config.use_sparse:
            logger.info_once("Using XPU MLA Sparse backend.")
            return AttentionBackendEnum.XPU_MLA_SPARSE.get_path()
        if attn_selector_config.use_mla:
            logger.info_once("Using Triton MLA backend on V1 engine.")
            return AttentionBackendEnum.TRITON_MLA.get_path()
        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info_once("Using Triton backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        elif dtype == torch.float32:
            logger.warning_once(
                "Flash Attention on XPU does not support float32 dtype. "
                "Falling back to Triton Attention backend."
            )
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend.")
            return AttentionBackendEnum.FLASH_ATTN.get_path()
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_mla: {attn_selector_config.use_mla}"
            )

        # No explicit selection: default to Flash Attention.
        logger.info_once("Using Flash Attention backend.")
        return AttentionBackendEnum.FLASH_ATTN.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends allowed for ViT attention on XPU."""
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Pick the ViT attention backend.

        An explicit ``backend`` must be in ``get_supported_vit_attn_backends()``
        and is returned unchanged. Otherwise FLASH_ATTN is returned.
        ``head_size`` and ``dtype`` are not consulted on XPU.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: "
                f"{cls.get_supported_vit_attn_backends()}."
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        logger.info_once(
            f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
        )
        return AttentionBackendEnum.FLASH_ATTN

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.xpu.set_device(device)

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Always return None on XPU (see comment below)."""
        # capacity format differs from cuda's and will cause unexpected
        # failure, so use None directly
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``torch.xpu.get_device_name`` for ``device_id``."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper.

        Setting ``XPU_USE_TRITON_KERNEL=1`` selects the GPU (Triton) wrapper;
        any other value, or leaving it unset, selects the XPU wrapper.
        """
        xpu_use_triton_kernel = os.getenv("XPU_USE_TRITON_KERNEL", "0") == "1"
        if not xpu_use_triton_kernel:
            return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
        else:
            return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total_memory`` device property of ``device_id``."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference (``torch.no_grad``)."""
        return torch.no_grad()

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static-graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for the XPU platform.

        - Default the KV-cache block size to 64 unless the user set one.
        - Default the attention backend to FLASH_ATTN.
        - Disable graph capture when it is unsupported (no XPU graph support,
          or DP/world size > 1). Force PIECEWISE mode for FLASH_ATTN.
        - Disable compilation when LoRA is enabled.
        - Select the XPU worker (only if still "auto"), and adjust the KV
          transfer and MLA scheduler settings.
        - Set ``UCX_MEMTYPE_CACHE=n`` in this process's environment.
        """
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        # in V1(or with chunked prefill) block_size is 64
        if cache_config and not cache_config.user_specified_block_size:
            cache_config.block_size = 64

        # lazy import to avoid circular import
        from vllm.config import CompilationMode, CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if compilation_config.compile_sizes is None:
            compilation_config.compile_sizes = []

        attention_config = vllm_config.attention_config
        if attention_config.backend is None:
            attention_config.backend = AttentionBackendEnum.FLASH_ATTN
        if not supports_xpu_graph():
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            logger.warning(
                "XPU Graph is not supported in the current PyTorch version, "
                "disabling cudagraph_mode."
            )
        elif parallel_config.world_size_across_dp > 1:
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            logger.warning(
                "XPU Graph doesn't support capture communication ops, "
                "disabling cudagraph_mode."
            )
        elif (
            attention_config.backend == AttentionBackendEnum.FLASH_ATTN
            and compilation_config.cudagraph_mode
            not in {CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE}
        ):
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            logger.warning(
                "FMHA sycl-tla kernels cannot be captured with XPU graphs, "
                "falling back to PIECEWISE graph mode on XPU platform."
            )

        if vllm_config.lora_config is not None:
            compilation_config.mode = CompilationMode.NONE

        # Only override worker_cls if it's still the default "auto"
        # This allows custom workers (like vllm-omni workers) to be used on XPU
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"

        if vllm_config.kv_transfer_config is not None:
            vllm_config.kv_transfer_config.enable_permute_local_kv = True

        if model_config and model_config.use_mla:
            # NOTE(review): the message mentions prefix caching, but only
            # chunked prefill is disabled below. Confirm which is intended.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            # Without chunked prefill, a whole prompt must fit in one batch.
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                model_config.max_model_len,
                vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

        # In some cases, the internal memory type cache can misdetect GPU
        # memory as host memory, also leading to invalid memory access.
        # This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.
        # ref. https://openucx.readthedocs.io/en/master/faq.html
        os.environ["UCX_MEMTYPE_CACHE"] = "n"

    @classmethod
    def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
        """No-op on XPU; block_size is set in ``check_and_update_config``."""
        # TODO: XPU still sets block_size in check_and_update_config.
        # Move that logic here so block_size is chosen by the backend.
        pass

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """XPU supports the hybrid KV cache."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """XPU supports static graph mode."""
        return True

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Pinned host memory is always reported as available on XPU."""
        return True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return device memory usage as reported by ``torch.xpu``.

        The peak counter is reset first, so ``max_memory_allocated`` reflects
        the allocation at the time of the call rather than a historical peak.
        """
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the FP8 dtype used on XPU (``float8_e4m3fn``)."""
        return torch.float8_e4m3fn

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """True if the default device's name contains "data center gpu"."""
        return "data center gpu" in cls.get_device_name().lower()

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the XPU device communicator.

        A warning is logged when this torch build lacks xccl support; the
        communicator path is returned either way.
        """
        from vllm.utils.torch_utils import supports_xccl

        if not supports_xccl():
            logger.warning(
                "xccl is not enabled in this torch build, communication"
                " is not available."
            )
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa

    @classmethod
    def device_count(cls) -> int:
        """Return the number of visible XPU devices."""
        return torch.xpu.device_count()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Reject bfloat16 on Intel Arc A770 because of known accuracy issues.

        Raises:
            ValueError: if ``dtype`` is bfloat16 and the device is an A770.
        """
        if dtype != torch.bfloat16:
            return
        # client gpu a770
        if "a770" in cls.get_device_name().lower():
            raise ValueError(
                "Intel Arc A770 have bfloat16 accuracy known issue. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half."
            )

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """XPU uses an opaque attention op."""
        return True

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU.

        Block indices select along dim 1 of both cache tensors.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU).

        Block indices select along dim 1 of both cache tensors.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """Return the ``max_compute_units`` property of ``device_id``."""
        return torch.xpu.get_device_properties(device_id).max_compute_units